Skip to content

Commit 84422d9

Browse files
committed
Fixes
1 parent e1ec74c commit 84422d9

File tree

2 files changed

+8
-14
lines changed

2 files changed

+8
-14
lines changed

examples/models/llama/runner/native.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from executorch.examples.models.llama.runner.generation import LlamaRunner
2424

2525
# Note: import this after portable_lib
26-
# from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
26+
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
2727
from executorch.kernels import quantized # noqa
2828

2929

@@ -50,17 +50,11 @@ def forward(
5050
tokens: torch.Tensor,
5151
input_pos: Optional[torch.Tensor] = None,
5252
) -> torch.Tensor:
53-
# TODO: in LlamaRunner there is a generate function that automatically generates
54-
# input_pos tensor and inputs it into the model. Atm TorchTune models use
55-
# kwargs for the input_pos, so we will need to make some changes. At least
56-
# for the time being, we can run the non-kv cache version of the Torchtune
57-
# model with just the tokens like below.
58-
return (self.model.forward((tokens,)))[0]
59-
# return (
60-
# self.model.forward((tokens, input_pos))
61-
# if input_pos is not None
62-
# else self.model.forward((tokens,))
63-
# )[0]
53+
return (
54+
self.model.forward((tokens, input_pos))
55+
if input_pos is not None
56+
else self.model.forward((tokens,))
57+
)[0]
6458

6559

6660
def build_args_parser() -> argparse.ArgumentParser:
@@ -69,7 +63,7 @@ def build_args_parser() -> argparse.ArgumentParser:
6963

7064
parser.add_argument(
7165
"--model",
72-
default="llama",
66+
default="llama3",
7367
choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
7468
)
7569

examples/models/llama3_2_vision/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(self, **kwargs):
111111
# Load checkpoint.
112112
missing, unexpected = self.model_.load_state_dict(
113113
checkpoint,
114-
strict=True,
114+
strict=False,
115115
assign=True,
116116
)
117117
if kwargs.get("verbose", False):

0 commit comments

Comments (0)