diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 94dbb5a0651..5734cd66ef7 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -5,7 +5,7 @@ @dataclass class ModelArgs: dim: int = 4096 - n_layers: int = 8 + n_layers: int = 1 n_heads: int = 32 n_kv_heads: Optional[int] = None vocab_size: int = 512 # Arbitrary value, should be defined later by tokenizer. diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index f4f51a921d1..b94adb5fa0c 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -19,7 +19,6 @@ class ExportLlamaLibTest(unittest.TestCase): - @unittest.skip("Keeps failing on trunk, temporarily skip") def test_has_expected_ops_and_op_counts(self): """ Checks the presence of unwanted expensive ops.