
Commit 8ea325b

bdu91 authored and LIT team committed
Add batch_size flag to lm_salience_demo and fix minor bugs.
PiperOrigin-RevId: 623274231
1 parent 8ca1312 commit 8ea325b

File tree

1 file changed (+13 −6 lines)


lit_nlp/examples/lm_salience_demo.py

Lines changed: 13 additions & 6 deletions
@@ -8,7 +8,7 @@
 To run with the default configuration (Gemma on TensorFlow via Keras):
 
   blaze run -c opt examples:lm_salience_demo -- \
-    --models=gemma_instruct_2b_en:gemma_instruct_2b_en \
+    --models=gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en \
     --port=8890 --alsologtostderr
 
 MODELS:
@@ -64,7 +64,7 @@
 
 _MODELS = flags.DEFINE_list(
     "models",
-    ["gemma_instruct_2b_en:gemma_instruct_2b_en"],
+    ["gemma_1.1_instruct_2b_en:gemma_1.1_instruct_2b_en"],
     "Models to load, as <name>:<path>. Path can be a URL, a local file path, or"
     " the name of a preset for the configured Deep Learning framework (either"
     " KerasNLP or HuggingFace Transformers; see --dl_framework for more). This"
@@ -91,6 +91,10 @@
     ),
 )
 
+_BATCH_SIZE = flags.DEFINE_integer(
+    "batch_size", 4, "The number of examples to process per batch.",
+)
+
 _DL_BACKEND = flags.DEFINE_enum(
     "dl_backend",
     "tensorflow",
@@ -278,18 +282,17 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
     path = file_cache.cached_path(
         path,
         extract_compressed_file=path.endswith(".tar.gz"),
-        copy_directories=True,
     )
 
-    if _DL_FRAMEWORK.value == "keras":
+    if _DL_FRAMEWORK.value == "kerasnlp":
       # pylint: disable=g-import-not-at-top
       from keras_nlp import models as keras_models
       from lit_nlp.examples.models import instrumented_keras_lms as lit_keras
       # pylint: enable=g-import-not-at-top
       # Load the weights once for the underlying Keras model.
       model = keras_models.CausalLM.from_preset(path)
       models |= lit_keras.initialize_model_group_for_salience(
-          model_name, model, max_length=512, batch_size=4
+          model_name, model, max_length=512, batch_size=_BATCH_SIZE.value
       )
       # Disable embeddings from the generation model.
       # TODO(lit-dev): re-enable embeddings if we can figure out why UMAP was
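The diff only shows `batch_size` being passed through to `initialize_model_group_for_salience`; how the wrapper consumes it is outside this change. The usual pattern in a model wrapper is to slice incoming examples into fixed-size chunks before each forward pass. A hypothetical sketch (the `predict` stub is illustrative, not LIT's actual implementation):

    from typing import Iterator, Sequence

    def chunked(
        examples: Sequence[dict], batch_size: int
    ) -> Iterator[Sequence[dict]]:
      """Yields successive slices of at most batch_size examples."""
      for start in range(0, len(examples), batch_size):
        yield examples[start:start + batch_size]

    def predict(examples: Sequence[dict], batch_size: int = 4) -> list[dict]:
      """Runs stubbed inference over fixed-size chunks of the input."""
      outputs = []
      for batch in chunked(examples, batch_size):
        # Stand-in for a real forward pass; a larger batch_size trades
        # peak accelerator memory for throughput, hence the new knob.
        outputs.extend({"output": ex["prompt"][::-1]} for ex in batch)
      return outputs

    print(predict([{"prompt": f"p{i}"} for i in range(10)], batch_size=4))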
@@ -301,7 +304,11 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
       # Assuming a valid decoder model name supported by
       # `transformers.AutoModelForCausalLM` is provided to "path".
       models |= pretrained_lms.initialize_model_group_for_salience(
-          model_name, path, framework=_DL_BACKEND.value, max_new_tokens=512
+          model_name,
+          path,
+          batch_size=_BATCH_SIZE.value,
+          framework=_DL_BACKEND.value,
+          max_new_tokens=512,
       )
 
   for name in datasets:
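With this change both backends honor the same flag. Based only on the keyword arguments visible in this diff, the HuggingFace path could be driven directly as below; the import path and the `gpt2` preset are assumptions for illustration, not part of the commit:

    from lit_nlp.examples.models import pretrained_lms

    models = {}
    models |= pretrained_lms.initialize_model_group_for_salience(
        "gpt2",                  # model_name; illustrative choice
        "gpt2",                  # path: preset name or local checkpoint
        batch_size=8,            # was hard-coded to 4 before this change
        framework="tensorflow",  # mirrors the --dl_backend default
        max_new_tokens=512,
    )

From the command line, the equivalent is simply appending `--batch_size=8` to the blaze invocation quoted in the module docstring above.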
