
Commit c464e0c

Merge pull request #65 from gsaon/granite-speech-3.3-8b

Granite speech support

2 parents 51e1c71 + 80d876a

File tree: 3 files changed, +345 −0 lines

granite/run_eval.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@

import argparse
import os
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, models
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm

# ensure installed transformers supports granite_speech
assert hasattr(models, "granite_speech")

wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision('high')

def main(args):
    processor = AutoProcessor.from_pretrained(args.model_id)
    tokenizer = processor.tokenizer
    model = AutoModelForSpeechSeq2Seq.from_pretrained(args.model_id).to(args.device)

    # create text prompt
    chat = [
        {
            "role": "system",
            "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
        },
        {
            "role": "user",
            "content": "<|audio|>can you transcribe the speech into a written format?",
        }
    ]

    text = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    gen_kwargs = {"max_new_tokens": args.max_new_tokens, "num_beams": args.num_beams}

    def benchmark(batch, min_new_tokens=None):
        # Load audio inputs
        audios = [audio["array"] for audio in batch["audio"]]
        minibatch_size = len(audios)
        texts = [text] * minibatch_size

        # START TIMING
        start_time = time.time()

        with torch.autocast(model.device.type, enabled=True):
            model_inputs = processor(
                texts,
                audios,
                device=args.device,  # Computation device; returned tensors are put on CPU
                return_tensors="pt",
            ).to(args.device)

            # Model Inference
            model_outputs = model.generate(
                **model_inputs,
                bos_token_id=tokenizer.bos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.0,
                **gen_kwargs,
                min_new_tokens=min_new_tokens,
            )

        # Transformers includes the input IDs in the response.
        num_input_tokens = model_inputs["input_ids"].shape[-1]
        new_tokens = model_outputs[:, num_input_tokens:]

        output_text = tokenizer.batch_decode(
            new_tokens, add_special_tokens=False, skip_special_tokens=True
        )

        # END TIMING
        runtime = time.time() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in output_text]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `librispeech_asr` for the LibriSpeech ASR dataset, or `common_voice` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `validation` for the dev split, or `test` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help="Number of beams for beam search.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=2,
        help="Number of warm-up steps to run before launching the timed runs.",
    )

    args = parser.parse_args()
    parser.set_defaults(streaming=False)

    main(args)
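
For context, the prompt-and-generate flow above can also be exercised on a single local recording outside the ESB pipeline. The sketch below reuses the same processor, chat template, and prompt-token slicing as run_eval.py; the file path, resampling step, model size, and device selection are illustrative assumptions, not part of this PR.

# Minimal sketch (not part of the PR): transcribe one local file with the same
# prompt + generate flow as run_eval.py. "sample.wav", the 16 kHz mono resampling,
# and the device string are placeholder assumptions.
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

model_id = "ibm-granite/granite-speech-3.3-2b"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id).to(device)

# Same chat prompt as the eval script: <|audio|> marks where the audio is inserted.
chat = [
    {"role": "system", "content": "You are Granite, developed by IBM. You are a helpful AI assistant"},
    {"role": "user", "content": "<|audio|>can you transcribe the speech into a written format?"},
]
text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# Hypothetical input file, resampled to 16 kHz mono here.
wav, sr = torchaudio.load("sample.wav")
wav = torchaudio.functional.resample(wav, sr, 16000).mean(dim=0).numpy()

inputs = processor([text], [wav], return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=200, num_beams=1)

# generate() echoes the prompt tokens, so keep only the newly generated ones.
new_tokens = outputs[:, inputs["input_ids"].shape[-1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0])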

granite/run_granite.sh

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@

#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=(
    "ibm-granite/granite-speech-3.3-2b"
    "ibm-granite/granite-speech-3.3-8b"
)

BATCH_SIZEs=(
    20
    12
)

NUM_BEAMS=1
MAX_NEW_TOKENS=200

num_models=${#MODEL_IDs[@]}

for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}
    BATCH_SIZE=${BATCH_SIZEs[$i]}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="voxpopuli" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --num_beams=${NUM_BEAMS} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="ami" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --num_beams=${NUM_BEAMS} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="earnings22" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --num_beams=${NUM_BEAMS} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="gigaspeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --num_beams=${NUM_BEAMS} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.clean" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --num_beams=${NUM_BEAMS} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.other" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --num_beams=${NUM_BEAMS} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="spgispeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --num_beams=${NUM_BEAMS} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="tedlium" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --num_beams=${NUM_BEAMS} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS}

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done
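
The trailing block of the script calls eval_utils.score_results once per model to aggregate the per-dataset manifests written by run_eval.py. The same scoring step can be rerun by hand from the normalizer directory, as in this small sketch; the results-directory layout is an assumption based on the `${RUNDIR}/results` path used above.

# Re-score previously written manifests for one model, mirroring the last step of
# run_granite.sh. Assumes it is run from the repository's normalizer/ directory and
# that run_eval.py has already written manifests under granite/results/ (assumed layout).
import os
import eval_utils  # helper shipped alongside the repository's normalizer package

results_dir = os.path.abspath("../granite/results")  # assumed layout
eval_utils.score_results(results_dir, "ibm-granite/granite-speech-3.3-8b")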
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@

evaluate
datasets==3.4.1
peft==0.13.1
torch==2.5.1
torchaudio==2.5.1
transformers @ https://github.com/huggingface/transformers/archive/main.zip
soundfile