Skip to content

Commit 523c249

Browse files
author
sanchit-gandhi
committed
updates
1 parent 8e47917 commit 523c249

File tree

5 files changed

+96
-77
lines changed

5 files changed

+96
-77
lines changed

transformers/run_eval.py

Lines changed: 80 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33

44
import torch
5+
from torch.nn.attention import sdpa_kernel, SDPBackend
56
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, AutoModelForCTC, AutoProcessor, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
67
import evaluate
78
from normalizer import data_utils
@@ -10,11 +11,14 @@
1011

1112
wer_metric = evaluate.load("wer")
1213

14+
torch.set_float32_matmul_precision('high')
15+
torch._logging.set_logs(graph_breaks=True, recompiles=True)
16+
1317

1418
def main(args):
1519
config = AutoConfig.from_pretrained(args.model_id)
1620
cls_model = AutoModelForSpeechSeq2Seq if type(config) in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING else AutoModelForCTC
17-
model = cls_model.from_pretrained(args.model_id, torch_dtype=torch.float16).to(args.device)
21+
model = cls_model.from_pretrained(args.model_id, torch_dtype=torch.bfloat16, attn_implementation="sdpa").to(args.device)
1822
processor = AutoProcessor.from_pretrained(args.model_id)
1923
model_input_name = processor.model_input_names[0]
2024

@@ -25,13 +29,11 @@ def main(args):
2529
gen_kwargs["language"] = "en"
2630
gen_kwargs["task"] = "transcribe"
2731

28-
dataset = data_utils.load_data(args)
29-
30-
if args.max_eval_samples is not None and args.max_eval_samples > 0:
31-
print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
32-
dataset = dataset.take(args.max_eval_samples)
33-
34-
dataset = data_utils.prepare_data(dataset)
32+
if args.torch_compile:
33+
model.forward = torch.compile(model.forward, mode=args.compile_mode, fullgraph=True)
34+
if model.can_generate():
35+
# enable static k/v cache for autoregressive models
36+
model.generation_config.cache_implementation = "static"
3537

3638
def benchmark(batch):
3739
# Load audio inputs
@@ -42,8 +44,15 @@ def benchmark(batch):
4244
start_time = time.time()
4345

4446
# 1. Pre-Processing
45-
if not model.can_generate() or len(audios[0]) > processor.feature_extractor.n_samples:
46-
# 1.1 Either CTC pre-processing (normalize to mean 0, std 1), or long-form Whisper processing
47+
# 1.1 Pad audios to max batch size if using torch compile to prevent re-compilations
48+
padding_size = None
49+
if minibatch_size != args.batch_size and args.torch_compile:
50+
padding_size = args.batch_size - minibatch_size
51+
padding_audios = [audios[-1] for _ in range(padding_size)]
52+
audios.extend(padding_audios)
53+
54+
if not model.can_generate(): #or len(audios[0]) > processor.feature_extractor.n_samples:
55+
# 1.2 CTC pre-processing (normalize to mean 0, std 1); the long-form Whisper condition is commented out in the guard above
4756
inputs = processor(
4857
audios,
4958
sampling_rate=16_000,
@@ -53,23 +62,29 @@ def benchmark(batch):
5362
return_attention_mask=True,
5463
)
5564
else:
56-
# 1.2 Standard Whisper processing: pad audios to 30-seconds and converted to log-mel
57-
inputs = processor(audios, sampling_rate=16_000, return_tensors="pt")
65+
# 1.3 Standard Whisper processing: pad audios to 30 seconds and convert to log-mel
66+
inputs = processor(audios, sampling_rate=16_000, return_tensors="pt", device=args.device)
5867

5968
inputs = inputs.to(args.device)
60-
inputs[model_input_name] = inputs[model_input_name].to(torch.float16)
69+
inputs[model_input_name] = inputs[model_input_name].to(torch.bfloat16)
6170

6271
# 2. Model Inference
63-
if model.can_generate():
64-
# 2.1 Auto-regressive generation for encoder-decoder models
65-
pred_ids = model.generate(**inputs, **gen_kwargs)
66-
else:
67-
# 2.2. Single forward pass for CTC
68-
with torch.no_grad():
69-
logits = model(**inputs)
70-
pred_ids = logits.argmax(-1)
71-
72-
# 3. Post-processing: convert token ids to text transcription
72+
with sdpa_kernel(SDPBackend.MATH if args.torch_compile else SDPBackend.FLASH_ATTENTION):
73+
if model.can_generate():
74+
# 2.1 Auto-regressive generation for encoder-decoder models
75+
pred_ids = model.generate(**inputs, **gen_kwargs)
76+
else:
77+
# 2.2. Single forward pass for CTC
78+
with torch.no_grad():
79+
logits = model(**inputs).logits
80+
pred_ids = logits.argmax(-1)
81+
82+
# 3. Post-processing
83+
# 3.1 Strip padded ids from predictions
84+
if padding_size is not None:
85+
pred_ids = pred_ids[:-padding_size, ...]
86+
87+
# 3.2 Convert token ids to text transcription
7388
pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)
7489

7590
# END TIMING
@@ -83,8 +98,31 @@ def benchmark(batch):
8398
batch["references"] = batch["norm_text"]
8499
return batch
85100

101+
if args.warmup_steps is not None:
102+
dataset = data_utils.load_data(args)
103+
dataset = data_utils.prepare_data(dataset)
104+
105+
num_warmup_samples = args.warmup_steps * args.batch_size
106+
if args.streaming:
107+
warmup_dataset = dataset.take(num_warmup_samples)
108+
else:
109+
warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
110+
warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True))
111+
112+
for _ in tqdm(warmup_dataset, desc="Warming up..."):
113+
continue
114+
115+
dataset = data_utils.load_data(args)
116+
if args.max_eval_samples is not None and args.max_eval_samples > 0:
117+
print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
118+
if args.streaming:
119+
dataset = dataset.take(args.max_eval_samples)
120+
else:
121+
dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
122+
dataset = data_utils.prepare_data(dataset)
123+
86124
dataset = dataset.map(
87-
benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"]
125+
benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
88126
)
89127

90128
all_results = {
@@ -94,7 +132,7 @@ def benchmark(batch):
94132
"references": [],
95133
}
96134
result_iter = iter(dataset)
97-
for result in tqdm(result_iter, desc="Samples"):
135+
for result in tqdm(result_iter, desc="Samples..."):
98136
for key in all_results:
99137
all_results[key].append(result[key])
100138

@@ -171,6 +209,23 @@ def benchmark(batch):
171209
action="store_false",
172210
help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
173211
)
212+
parser.add_argument(
213+
"--torch_compile",
214+
action="store_true",
215+
help="Whether to JIT compile the forward pass of the model.",
216+
)
217+
parser.add_argument(
218+
"--compile_mode",
219+
type=str,
220+
default="max-autotune",
221+
help="Mode for torch compiling model forward pass. Can be either 'default', 'reduce-overhead', 'max-autotune' or 'max-autotune-no-cudagraphs'.",
222+
)
223+
parser.add_argument(
224+
"--warmup_steps",
225+
type=int,
226+
default=10,
227+
help="Number of warm-up steps to run before launching the timed runs.",
228+
)
174229
args = parser.parse_args()
175230
parser.set_defaults(streaming=False)
176231

transformers/run_hubert.sh

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,6 @@ do
8484
--batch_size=${BATCH_SIZE} \
8585
--max_eval_samples=-1
8686

87-
python run_eval.py \
88-
--model_id=${MODEL_ID} \
89-
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
90-
--dataset="common_voice" \
91-
--split="test" \
92-
--device=0 \
93-
--batch_size=${BATCH_SIZE} \
94-
--max_eval_samples=-1
95-
9687
# Evaluate results
9788
RUNDIR=`pwd` && \
9889
cd ../normalizer && \

transformers/run_mms.sh

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,6 @@ do
8484
--batch_size=${BATCH_SIZE} \
8585
--max_eval_samples=-1
8686

87-
python run_eval.py \
88-
--model_id=${MODEL_ID} \
89-
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
90-
--dataset="common_voice" \
91-
--split="test" \
92-
--device=0 \
93-
--batch_size=${BATCH_SIZE} \
94-
--max_eval_samples=-1
95-
9687
# Evaluate results
9788
RUNDIR=`pwd` && \
9889
cd ../normalizer && \

transformers/run_wav2vec2.sh

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ do
1616
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
1717
--dataset="ami" \
1818
--split="test" \
19-
--device=1 \
19+
--device=0 \
2020
--batch_size=${BATCH_SIZE} \
2121
--max_eval_samples=-1
2222

@@ -26,7 +26,7 @@ do
2626
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
2727
--dataset="earnings22" \
2828
--split="test" \
29-
--device=1 \
29+
--device=0 \
3030
--batch_size=${BATCH_SIZE} \
3131
--max_eval_samples=-1
3232

@@ -35,7 +35,7 @@ do
3535
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
3636
--dataset="gigaspeech" \
3737
--split="test" \
38-
--device=1 \
38+
--device=0 \
3939
--batch_size=${BATCH_SIZE} \
4040
--max_eval_samples=-1
4141

@@ -44,7 +44,7 @@ do
4444
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
4545
--dataset="librispeech" \
4646
--split="test.clean" \
47-
--device=1 \
47+
--device=0 \
4848
--batch_size=${BATCH_SIZE} \
4949
--max_eval_samples=-1
5050

@@ -53,7 +53,7 @@ do
5353
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
5454
--dataset="librispeech" \
5555
--split="test.other" \
56-
--device=1 \
56+
--device=0 \
5757
--batch_size=${BATCH_SIZE} \
5858
--max_eval_samples=-1
5959

@@ -62,7 +62,7 @@ do
6262
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
6363
--dataset="spgispeech" \
6464
--split="test" \
65-
--device=1 \
65+
--device=0 \
6666
--batch_size=${BATCH_SIZE} \
6767
--max_eval_samples=-1
6868

@@ -71,7 +71,7 @@ do
7171
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
7272
--dataset="tedlium" \
7373
--split="test" \
74-
--device=1 \
74+
--device=0 \
7575
--batch_size=${BATCH_SIZE} \
7676
--max_eval_samples=-1
7777

@@ -80,16 +80,7 @@ do
8080
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
8181
--dataset="voxpopuli" \
8282
--split="test" \
83-
--device=1 \
84-
--batch_size=${BATCH_SIZE} \
85-
--max_eval_samples=-1
86-
87-
python run_eval.py \
88-
--model_id=${MODEL_ID} \
89-
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
90-
--dataset="common_voice" \
91-
--split="test" \
92-
--device=1 \
83+
--device=0 \
9384
--batch_size=${BATCH_SIZE} \
9485
--max_eval_samples=-1
9586

transformers/run_whisper.sh

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ do
1616
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
1717
--dataset="ami" \
1818
--split="test" \
19-
--device=2 \
19+
--device=0 \
2020
--batch_size=${BATCH_SIZE} \
2121
--max_eval_samples=-1
2222

@@ -25,7 +25,7 @@ do
2525
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
2626
--dataset="earnings22" \
2727
--split="test" \
28-
--device=2 \
28+
--device=0 \
2929
--batch_size=${BATCH_SIZE} \
3030
--max_eval_samples=-1
3131

@@ -34,7 +34,7 @@ do
3434
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
3535
--dataset="gigaspeech" \
3636
--split="test" \
37-
--device=2 \
37+
--device=0 \
3838
--batch_size=${BATCH_SIZE} \
3939
--max_eval_samples=-1
4040

@@ -43,7 +43,7 @@ do
4343
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
4444
--dataset="librispeech" \
4545
--split="test.clean" \
46-
--device=2 \
46+
--device=0 \
4747
--batch_size=${BATCH_SIZE} \
4848
--max_eval_samples=-1
4949

@@ -52,7 +52,7 @@ do
5252
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
5353
--dataset="librispeech" \
5454
--split="test.other" \
55-
--device=2 \
55+
--device=0 \
5656
--batch_size=${BATCH_SIZE} \
5757
--max_eval_samples=-1
5858

@@ -61,7 +61,7 @@ do
6161
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
6262
--dataset="spgispeech" \
6363
--split="test" \
64-
--device=2 \
64+
--device=0 \
6565
--batch_size=${BATCH_SIZE} \
6666
--max_eval_samples=-1
6767

@@ -70,7 +70,7 @@ do
7070
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
7171
--dataset="tedlium" \
7272
--split="test" \
73-
--device=2 \
73+
--device=0 \
7474
--batch_size=${BATCH_SIZE} \
7575
--max_eval_samples=-1
7676

@@ -79,16 +79,7 @@ do
7979
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
8080
--dataset="voxpopuli" \
8181
--split="test" \
82-
--device=2 \
83-
--batch_size=${BATCH_SIZE} \
84-
--max_eval_samples=-1
85-
86-
python run_eval.py \
87-
--model_id=${MODEL_ID} \
88-
--dataset_path="hf-audio/esb-datasets-test-only-sorted" \
89-
--dataset="common_voice" \
90-
--split="test" \
91-
--device=2 \
82+
--device=0 \
9283
--batch_size=${BATCH_SIZE} \
9384
--max_eval_samples=-1
9485

0 commit comments

Comments
 (0)