Commit 042180d

fix(server): Only pad to multiple of 8 on GPUs
Parent: a298503

2 files changed: 4 additions, 2 deletions

server/text_generation/models/causal_lm.py

Lines changed: 2 additions & 1 deletion
@@ -71,8 +71,9 @@ def from_pb(
                 )
             )
 
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
 
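Why gate this on the device: padding every batch up to a multiple of 8 helps GPU kernels (NVIDIA Tensor Cores run fastest when matrix dimensions are divisible by 8), while on CPU the extra pad tokens are pure overhead. Below is a minimal sketch, not part of the commit, of what the standard Hugging Face pad_to_multiple_of argument does to batch shapes; the gpt2 checkpoint and sample inputs are illustrative assumptions, and exact shapes depend on tokenization:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

batch = ["Hello world", "A longer request that will need several more tokens"]

# padding=True alone pads to the longest sequence in the batch.
plain = tokenizer(batch, return_tensors="pt", padding=True)

# pad_to_multiple_of=8 additionally rounds that length up to a multiple of 8.
padded = tokenizer(batch, return_tensors="pt", padding=True, pad_to_multiple_of=8)

print(plain["input_ids"].shape)   # e.g. torch.Size([2, 9])
print(padded["input_ids"].shape)  # e.g. torch.Size([2, 16])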

server/text_generation/models/seq2seq_lm.py

Lines changed: 2 additions & 1 deletion
@@ -83,8 +83,9 @@ def from_pb(
             )
 
         # Tokenize batch
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
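The same one-line gate is applied in both models, so padding behavior stays consistent between causal and seq2seq batches. A standalone sketch (my own rewrite, not repo code) of how the check resolves at runtime; note that a stock PyTorch CUDA device stringifies as "cuda" or "cuda:0", so an explicit device.type test is shown alongside as the plain-PyTorch way to ask "is this a GPU device":

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The commit's check verbatim: pad to a multiple of 8 only when the device
# string contains "gpu"; otherwise leave pad_to_multiple_of unset (None) and
# let the tokenizer pad just to the longest sequence in the batch.
pad_commit_style = 8 if "gpu" in str(device) else None

# Equivalent intent expressed via the device type (an assumption, not the
# commit's code): CUDA devices report device.type == "cuda".
pad_by_type = 8 if device.type == "cuda" else None

print(str(device), pad_commit_style, pad_by_type)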
