from executorch.examples.models.llama.runner.generation import LlamaRunner

# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
from executorch.kernels import quantized  # noqa
def forward(
    self,  # NOTE(review): reconstructed from a diff hunk header — confirm signature upstream
    tokens: torch.Tensor,
    input_pos: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run one forward pass of the wrapped model.

    Args:
        tokens: Input token-id tensor passed to the model.
        input_pos: Optional position tensor; when provided (KV-cache
            models) it is forwarded to the model alongside ``tokens``.

    Returns:
        The first element of the model's output.
    """
    # The wrapped model takes its inputs as a single tuple and returns an
    # indexable result; callers only consume the first entry.
    return (
        self.model.forward((tokens, input_pos))
        if input_pos is not None
        else self.model.forward((tokens,))
    )[0]
6559
6660def build_args_parser () -> argparse .ArgumentParser :
@@ -69,7 +63,7 @@ def build_args_parser() -> argparse.ArgumentParser:
6963
7064 parser .add_argument (
7165 "--model" ,
72- default = "llama " ,
66+ default = "llama3 " ,
7367 choices = EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS ,
7468 )
7569
0 commit comments