88To run with the default configuration (Gemma on TensorFlow via Keras):
99
1010 blaze run -c opt examples:lm_salience_demo -- \
11- --models=gemma_instruct_2b_en:gemma_instruct_2b_en \
11+ --models=gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en \
1212 --port=8890 --alsologtostderr
1313
1414MODELS:
6464
6565_MODELS = flags .DEFINE_list (
6666 "models" ,
67- ["gemma_instruct_2b_en:gemma_instruct_2b_en " ],
67+ ["gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en " ],
6868 "Models to load, as <name>:<path>. Path can be a URL, a local file path, or"
6969 " the name of a preset for the configured Deep Learning framework (either"
7070 " KerasNLP or HuggingFace Transformers; see --dl_framework for more). This"
9191 ),
9292)
9393
94+ _BATCH_SIZE = flags .DEFINE_integer (
95+ "batch_size" , 4 , "The number of examples to process per batch." ,
96+ )
97+
9498_DL_BACKEND = flags .DEFINE_enum (
9599 "dl_backend" ,
96100 "tensorflow" ,
@@ -278,18 +282,17 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
278282 path = file_cache .cached_path (
279283 path ,
280284 extract_compressed_file = path .endswith (".tar.gz" ),
281- copy_directories = True ,
282285 )
283286
284- if _DL_FRAMEWORK .value == "keras " :
287+ if _DL_FRAMEWORK .value == "kerasnlp " :
285288 # pylint: disable=g-import-not-at-top
286289 from keras_nlp import models as keras_models
287290 from lit_nlp .examples .models import instrumented_keras_lms as lit_keras
288291 # pylint: enable=g-import-not-at-top
289292 # Load the weights once for the underlying Keras model.
290293 model = keras_models .CausalLM .from_preset (path )
291294 models |= lit_keras .initialize_model_group_for_salience (
292- model_name , model , max_length = 512 , batch_size = 4
295+ model_name , model , max_length = 512 , batch_size = _BATCH_SIZE . value
293296 )
294297 # Disable embeddings from the generation model.
295298 # TODO(lit-dev): re-enable embeddings if we can figure out why UMAP was
@@ -301,7 +304,11 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
301304 # Assuming a valid decoder model name supported by
302305 # `transformers.AutoModelForCausalLM` is provided to "path".
303306 models |= pretrained_lms .initialize_model_group_for_salience (
304- model_name , path , framework = _DL_BACKEND .value , max_new_tokens = 512
307+ model_name ,
308+ path ,
309+ batch_size = _BATCH_SIZE .value ,
310+ framework = _DL_BACKEND .value ,
311+ max_new_tokens = 512 ,
305312 )
306313
307314 for name in datasets :
0 commit comments