Skip to content

Commit 2d5dec7

Browse files
committed
Run Granite Speech with bfloat16.
Increase batch sizes to improve GPU utilization.
1 parent 76e5444 commit 2d5dec7

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

granite/run_eval.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
def main(args):
1717
processor = AutoProcessor.from_pretrained(args.model_id)
1818
tokenizer = processor.tokenizer
19-
model = AutoModelForSpeechSeq2Seq.from_pretrained(args.model_id).to(args.device)
19+
model = AutoModelForSpeechSeq2Seq.from_pretrained(args.model_id, torch_dtype=torch.bfloat16).to(args.device)
2020

2121
# create text prompt
2222
chat = [
@@ -45,24 +45,24 @@ def benchmark(batch, min_new_tokens=None):
4545
# START TIMING
4646
start_time = time.time()
4747

48-
with torch.autocast(model.device.type, enabled=True):
49-
model_inputs = processor(
50-
texts,
51-
audios,
52-
device=args.device, # Computation device; returned tensors are put on CPU
53-
return_tensors="pt",
54-
).to(args.device)
55-
56-
# Model Inference
57-
model_outputs = model.generate(
58-
**model_inputs,
59-
bos_token_id=tokenizer.bos_token_id,
60-
pad_token_id=tokenizer.pad_token_id,
61-
eos_token_id=tokenizer.eos_token_id,
62-
repetition_penalty=1.0,
63-
**gen_kwargs,
64-
min_new_tokens=min_new_tokens,
65-
)
48+
# with torch.autocast(model.device.type, enabled=True):
49+
model_inputs = processor(
50+
texts,
51+
audios,
52+
device=args.device, # Computation device; returned tensors are put on CPU
53+
return_tensors="pt",
54+
).to(args.device)
55+
56+
# Model Inference
57+
model_outputs = model.generate(
58+
**model_inputs,
59+
bos_token_id=tokenizer.bos_token_id,
60+
pad_token_id=tokenizer.pad_token_id,
61+
eos_token_id=tokenizer.eos_token_id,
62+
repetition_penalty=1.0,
63+
**gen_kwargs,
64+
min_new_tokens=min_new_tokens,
65+
)
6666

6767
# Transformers includes the input IDs in the response.
6868
num_input_tokens = model_inputs["input_ids"].shape[-1]

granite/run_granite.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ MODEL_IDs=(
88
)
99

1010
BATCH_SIZEs=(
11-
20
12-
12
11+
160
12+
64
1313
)
1414

1515
NUM_BEAMS=1

0 commit comments

Comments (0)