 class TestLlama(unittest.TestCase):
     """
     Test class of Llama models. Type of Llama model depends on command line parameters:
-    --llama_inputs <path to .pt file> <path to json file>
-    Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
+    --llama_inputs <path to .pt file> <path to json file> <name of model variant>
+    Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json stories110m
+    For more examples and info see examples/models/llama/README.md.
     """

     def prepare_model(self):

         checkpoint = None
         params_file = None
+        usage = "To run use --llama_inputs <.pt/.pth> <.json> <name>"
+
         if conftest.is_option_enabled("llama_inputs"):
             param_list = conftest.get_option("llama_inputs")
-            assert (
-                isinstance(param_list, list) and len(param_list) == 2
-            ), "invalid number of inputs for --llama_inputs"
+
+            if not isinstance(param_list, list) or len(param_list) != 3:
+                raise RuntimeError(
+                    f"Invalid number of inputs for --llama_inputs. {usage}"
+                )
+            if not all(isinstance(param, str) for param in param_list):
+                raise RuntimeError(
+                    f"All --llama_inputs are expected to be strings. {usage}"
+                )
+
             checkpoint = param_list[0]
             params_file = param_list[1]
-            assert isinstance(checkpoint, str) and isinstance(
-                params_file, str
-            ), "invalid input for --llama_inputs"
+            model_name = param_list[2]
         else:
             logger.warning(
-                "Skipping Llama test because of lack of input. To run use --llama_inputs <.pt> <.json> "
+                f"Skipping Llama tests because of missing --llama_inputs. {usage}"
             )
             return None, None, None

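The test pulls its inputs through `conftest.is_option_enabled` and `conftest.get_option`, which live outside this diff. As a hedged sketch of the kind of pytest plumbing assumed here (the option name matches the test; the hook bodies and the `_test_options` dict are illustrative, not the actual conftest):

```python
# Sketch of a conftest.py registering --llama_inputs; names other than the
# option itself and is_option_enabled/get_option are hypothetical.
_test_options = {}


def pytest_addoption(parser):
    # Three string values are now expected: <.pt/.pth> <.json> <model name>.
    parser.addoption("--llama_inputs", nargs="+", default=None)


def pytest_configure(config):
    if config.getoption("--llama_inputs") is not None:
        _test_options["llama_inputs"] = config.getoption("--llama_inputs")


def is_option_enabled(option: str) -> bool:
    # True when the option was passed on the pytest command line.
    return _test_options.get(option) is not None


def get_option(option: str):
    # The raw list of strings supplied after --llama_inputs.
    return _test_options.get(option)
```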
@@ -71,7 +79,7 @@ def prepare_model(self):
             "-p",
             params_file,
             "--model",
-            "stories110m",
+            model_name,
         ]
         parser = build_args_parser()
         args = parser.parse_args(args)
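For orientation, with the docstring's sample inputs the tail of the argument list assembled above would look roughly as follows; the flags before `"-p"` are outside this excerpt and the values are illustrative only:

```python
args = [
    # ... flags not shown in this hunk ...
    "-p", "stories110M/params.json",  # params_file = param_list[1]
    "--model", "stories110m",         # model_name = param_list[2], previously hard-coded
]
```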
@@ -122,6 +130,7 @@ def test_llama_tosa_BI(self):
             .quantize()
             .export()
             .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
             .run_method_and_compare_outputs(
                 inputs=llama_inputs,
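The new `.check_count(...)` stage asserts that the lowered graph contains exactly one `torch.ops.higher_order.executorch_call_delegate` call, i.e. the whole Llama model ends up in a single delegate partition instead of being split around unsupported ops. A minimal sketch of the idea, not the ArmTester implementation, assuming an FX graph module produced by the lowering step:

```python
def count_delegate_calls(graph_module) -> int:
    # Count call_function nodes whose target refers to the delegate-call op.
    target_name = "executorch_call_delegate"
    return sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function" and target_name in str(node.target)
    )


# check_count({...: 1}) then amounts to:
# assert count_delegate_calls(lowered_graph_module) == 1
```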