Commit 7a7d967

fix eager run for cuda
ghstack-source-id: 35f8e34
Pull Request resolved: #6365
Parent: 7493aae

2 files changed: +19, -12 lines

examples/models/llama/runner/eager.py

Lines changed: 6 additions & 6 deletions
@@ -33,13 +33,13 @@ def __init__(self, args):
             use_kv_cache=args.use_kv_cache,
             **params,
         )
-        super().__init__(tokenizer_path=args.tokenizer_path, model_args=model_args)
-        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
-        self.model = (
-            manager.model.eval().to(device="cuda")
-            if torch.cuda.is_available()
-            else manager.model.eval().to(device="cpu")
+        super().__init__(
+            tokenizer_path=args.tokenizer_path,
+            model_args=model_args,
+            device="cuda" if torch.cuda.is_available() else "cpu",
         )
+        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        self.model = manager.model.eval().to(device=self.device)
 
     def forward(
         self,
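In short, the device is now chosen once in the constructor, passed to the base class, and reused when the model is moved. A minimal, self-contained sketch of that pattern follows; the class names are placeholders, not code from this commit:

import torch

class RunnerBase:  # stand-in for LlamaRunner
    def __init__(self, device: str = "cpu"):
        self.device = device

class EagerRunner(RunnerBase):  # stand-in for the runner in eager.py
    def __init__(self, model: torch.nn.Module):
        # Pick the device once and hand it to the base class.
        super().__init__(device="cuda" if torch.cuda.is_available() else "cpu")
        # The model (and, later, the input tensors) now shares self.device
        # instead of each call site choosing "cuda" or "cpu" with its own ternary.
        self.model = model.eval().to(device=self.device)

runner = EagerRunner(torch.nn.Linear(4, 4))
print(runner.device, next(runner.model.parameters()).device)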

examples/models/llama/runner/generation.py

Lines changed: 13 additions & 6 deletions
@@ -51,10 +51,11 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 
 
 class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs):
+    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
         self.params = model_args
         self.tokenizer = get_tokenizer(tokenizer_path)
         assert model_args.vocab_size == self.tokenizer.n_words
+        self.device = device
 
     @abstractmethod
     def forward(
@@ -73,9 +74,9 @@ def generate( # noqa: C901
     ) -> List[int]:
         # prefill
         logits = self.forward(
-            tokens=torch.tensor([prompt_tokens], dtype=torch.long),
+            tokens=torch.tensor([prompt_tokens], dtype=torch.long).to(self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long)
+                torch.tensor([0], dtype=torch.long).to(self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -87,11 +88,17 @@ def generate( # noqa: C901
         while len(tokens) < self.params.max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
-                    tokens=torch.tensor([[current_token]], dtype=torch.long),
-                    input_pos=torch.tensor([len(tokens) - 1], dtype=torch.long),
+                    tokens=torch.tensor([[current_token]], dtype=torch.long).to(
+                        self.device
+                    ),
+                    input_pos=torch.tensor([len(tokens) - 1], dtype=torch.long).to(
+                        self.device
+                    ),
                 )
             else:
-                logits = self.forward(tokens=torch.tensor([tokens], dtype=torch.long))
+                logits = self.forward(
+                    tokens=torch.tensor([tokens], dtype=torch.long).to(self.device)
+                )
             current_token = next_token(logits, temperature, top_p)
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self, "stop_tokens") and current_token in self.stop_tokens
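For context, here is a minimal sketch (not from this repository) of the mismatch these .to(self.device) calls prevent: tensors created with torch.tensor(...) live on the CPU by default, so they must be moved before being fed to a model that was placed on CUDA.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Embedding(32, 8).to(device)  # stand-in for the eager Llama model

prompt_tokens = [1, 2, 3]
tokens = torch.tensor([prompt_tokens], dtype=torch.long)  # allocated on CPU by default

# On a CUDA machine, model(tokens) would raise a device-mismatch RuntimeError;
# moving the tensor first, as generate() now does, keeps inputs and weights together.
logits = model(tokens.to(device))
print(logits.shape, logits.device)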
