Skip to content

Commit 77ad854

Browse files
authored
model-conversion : cast logits to float32 (#18009)
1 parent 609a2d0 commit 77ad854

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def fn(_m, input, output):
200200
logits = outputs.logits
201201

202202
# Extract logits for the last token (next token prediction)
203-
last_logits = logits[0, -1, :].cpu().numpy()
203+
last_logits = logits[0, -1, :].float().cpu().numpy()
204204

205205
print(f"Logits shape: {logits.shape}")
206206
print(f"Last token logits shape: {last_logits.shape}")

0 commit comments

Comments
 (0)