Commit 686cc66

fix(server): Check for device type correctly when determining initial padding (IBM#16)
AFAIK there is no torch device type called "gpu".
1 parent 611e21c commit 686cc66

2 files changed: +2 −2 lines

server/text_generation/models/causal_lm.py (1 addition, 1 deletion)

@@ -65,7 +65,7 @@ def from_pb(
         )
         all_logprobs.append(None)

-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",

server/text_generation/models/seq2seq_lm.py (1 addition, 1 deletion)

@@ -77,7 +77,7 @@ def from_pb(
         decoder_logprobs.append(None)

         # Tokenize batch
-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
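The bug matters because `str(torch.device("cuda:0"))` yields `"cuda:0"`, which never contains the substring `"gpu"`, so the old check silently disabled padding on GPU in every case. A minimal sketch of both checks (plain strings stand in for `torch.device`; the helper names are hypothetical):

```python
def pad_to_multiple_of_fixed(device_str: str):
    # torch.device.type is the part before the optional ":index",
    # e.g. torch.device("cuda:0").type == "cuda"
    device_type = device_str.split(":")[0]
    return 8 if device_type == "cuda" else None

def pad_to_multiple_of_old(device_str: str):
    # Buggy check: no torch device ever stringifies to something
    # containing "gpu", so GPU padding was never applied.
    return 8 if "gpu" in device_str else None

for d in ("cpu", "cuda", "cuda:0"):
    print(d, pad_to_multiple_of_old(d), pad_to_multiple_of_fixed(d))
```

Comparing `device.type` against `"cuda"` matches how PyTorch itself distinguishes devices and also handles indexed devices like `cuda:1` correctly.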
