Skip to content

Commit f504cc5

Browse files
committed
Merge branch 'jz/native-runner-tt' into jz/tt-llama-kv-cache
2 parents a36703e + 2b9f281 commit f504cc5

File tree

3 files changed

+4
-0
lines changed

3 files changed

+4
-0
lines changed

.ci/scripts/gather_test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"resnet50": "linux.12xlarge",
2626
"llava": "linux.12xlarge",
2727
"llama3_2_vision_encoder": "linux.12xlarge",
28+
"llama3_2_text_decoder": "linux.12xlarge",
2829
# This one causes timeout on smaller runner, the root cause is unclear (T161064121)
2930
"dl3": "linux.12xlarge",
3031
"emformer_join": "linux.12xlarge",

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"llama2": ("llama", "Llama2Model"),
2020
"llama": ("llama", "Llama2Model"),
2121
"llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
22+
"llama3_2_text_decoder": ("llama3_2_vision", "Llama3_2Decoder"),
2223
"lstm": ("lstm", "LSTMModel"),
2324
"mobilebert": ("mobilebert", "MobileBertModelExample"),
2425
"mv2": ("mobilenet_v2", "MV2Model"),

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,8 @@ def _load_llama_model(
888888
if modelname == "llama3_2_vision":
889889
module_name = "llama3_2_vision"
890890
model_class_name = "Llama3_2Decoder"
891+
else:
892+
raise ValueError(f"{modelname} is not a valid Llama model.")
891893
else:
892894
raise ValueError(f"{modelname} is not a valid Llama model.")
893895

0 commit comments

Comments
 (0)