Skip to content

Commit 9bd405f

Browse files
committed
Lint
1 parent c80ce1c commit 9bd405f

File tree

3 files changed

+12
-11
lines changed

3 files changed

+12
-11
lines changed

examples/models/llama3_2_vision/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def get_dynamic_shapes(self):
175175
# "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
176176
# "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
177177
"mask": {0: batch_size, 1: dim_seq_len, 2: None},
178-
"input_pos" : {0: batch_size, 1: dim_seq_len},
178+
"input_pos": {0: batch_size, 1: dim_seq_len},
179179
}
180180
else:
181181
dynamic_shapes = {

examples/models/llama3_2_vision/runner/eager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from executorch.examples.models.llama.export_llama_lib import (
1414
_prepare_for_llama_export,
1515
build_args_parser as _build_args_parser,
16-
TORCHTUNE_DEFINED_MODELS,
1716
)
18-
from executorch.examples.models.llama3_2_vision.runner.generation import TorchTuneLlamaRunner
17+
from executorch.examples.models.llama3_2_vision.runner.generation import (
18+
TorchTuneLlamaRunner,
19+
)
1920
from executorch.extension.llm.export import LLMEdgeManager
2021

2122

examples/models/llama3_2_vision/runner/generation.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from abc import ABC, abstractmethod
8-
from typing import List, Optional, TypedDict
7+
from typing import List
98

109
import torch
11-
12-
from executorch.extension.llm.tokenizer.utils import get_tokenizer
13-
from executorch.examples.models.llama.runner.generation import LlamaRunner, next_token, sample_top_p
10+
from executorch.examples.models.llama.runner.generation import LlamaRunner, next_token
1411

1512

1613
class TorchTuneLlamaRunner(LlamaRunner):
@@ -54,13 +51,17 @@ def generate( # noqa: C901
5451
mask = self.causal_mask[None, :seq_len]
5552
if self.use_kv_cache:
5653
logits = self.forward(
57-
tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
54+
tokens=torch.tensor(
55+
[prompt_tokens], dtype=torch.long, device=self.device
56+
),
5857
input_pos=input_pos,
5958
mask=mask,
6059
)
6160
else:
6261
logits = self.forward(
63-
tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
62+
tokens=torch.tensor(
63+
[prompt_tokens], dtype=torch.long, device=self.device
64+
),
6465
)
6566

6667
# Only need the last logit.
@@ -98,4 +99,3 @@ def generate( # noqa: C901
9899
seq_len += 1
99100

100101
return tokens if echo else tokens[len(prompt_tokens) :]
101-

0 commit comments

Comments
 (0)