Skip to content

Commit f4060c0

Browse files
tjohnson31415njhill
authored andcommitted
fix: use torch instead of numpy to resolve device mismatch bug
Signed-off-by: Travis Johnson <[email protected]>
1 parent 63142fc commit f4060c0

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

server/text_generation_server/models/causal_lm.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import time
33
from operator import itemgetter
44

5-
import numpy as np
65
import torch
76

87
from dataclasses import dataclass
@@ -163,15 +162,11 @@ def from_pb(
163162

164163
# Padded all_input_ids_tensor; the maximum length of any sequence is the max
165164
# (padded) input sequence length + the max output length
166-
all_input_ids_tensor = np.full(
165+
all_input_ids_tensor = torch.full(
167166
(batch_size, tokenize_length + padding_right_offset),
168167
tokenizer.pad_token_id,
169168
)
170169
all_input_ids_tensor[:, :all_input_ids.shape[1]] = all_input_ids
171-
# Create tensors on device
172-
all_input_ids_tensor = all_input_ids.new_tensor(
173-
all_input_ids_tensor,
174-
)
175170

176171
if prefix_ids:
177172
# Get input embeddings

0 commit comments

Comments
 (0)