
Commit b4de2c9

Yiming Wang (freewym) authored and committed
Adds Phi-4-Multimodal
1 parent 067b21c commit b4de2c9

File tree: 2 files changed, +358 −0 lines changed


phi/run_eval.py

Lines changed: 248 additions & 0 deletions
@@ -0,0 +1,248 @@
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, StoppingCriteria, StoppingCriteriaList
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm

wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision('high')


class MultipleTokenBatchStoppingCriteria(StoppingCriteria):
    """Stopping criteria capable of receiving multiple stop-tokens and handling batched inputs."""

    def __init__(self, stop_tokens: torch.LongTensor, batch_size: int = 1) -> None:
        """Initialize the multiple token batch stopping criteria.

        Args:
            stop_tokens: Stop-tokens.
            batch_size: Batch size.

        """

        self.stop_tokens = stop_tokens
        self.max_stop_tokens = stop_tokens.shape[-1]
        self.stop_tokens_idx = torch.zeros(batch_size, dtype=torch.long, device=stop_tokens.device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only gather the maximum number of inputs compatible with stop tokens
        # and checks whether generated inputs are equal to `stop_tokens`
        generated_inputs = torch.eq(input_ids[:, -self.max_stop_tokens :].unsqueeze(1), self.stop_tokens)
        equal_generated_inputs = torch.all(generated_inputs, dim=2)

        # Mark the position where a stop token has been produced for each input in the batch,
        # but only if the corresponding entry is not already set
        sequence_idx = torch.any(equal_generated_inputs, dim=1)
        sequence_set_mask = self.stop_tokens_idx == 0
        self.stop_tokens_idx[sequence_idx & sequence_set_mask] = input_ids.shape[-1]

        return torch.all(self.stop_tokens_idx)


def main(args):
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation="flash_attention_2",
    ).to(args.device)
    model.eval()
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)

    user = "<|user|>"
    assistant = "<|assistant|>"
    prompt_suffix = "<|end|>"

    prompt = f"{user}<|audio_1|>{args.user_prompt}{prompt_suffix}{assistant}"

    gen_kwargs = {"max_new_tokens": args.max_new_tokens}

    stop_tokens = [prompt_suffix, processor.tokenizer.eos_token]
    stop_tokens_ids = processor.tokenizer(stop_tokens, add_special_tokens=False, padding="longest", return_tensors="pt")["input_ids"]
    stop_tokens_ids = stop_tokens_ids.to(model.device)

    def benchmark(batch, min_new_tokens=None):
        # Load audio inputs
        audios = [(audio["array"], audio["sampling_rate"]) for audio in batch["audio"]]
        minibatch_size = len(audios)
        gen_kwargs["stopping_criteria"] = StoppingCriteriaList([MultipleTokenBatchStoppingCriteria(stop_tokens_ids, batch_size=minibatch_size)])

        # START TIMING
        start_time = time.time()

        with torch.autocast(model.device.type, enabled=True):
            inputs = processor(text=[prompt] * minibatch_size, audios=audios, return_tensors="pt").to(args.device)

            # Model Inference
            pred_ids = model.generate(
                **inputs,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                **gen_kwargs,
                min_new_tokens=min_new_tokens,
            )

        # Gather the sequence index of the stop token
        stop_tokens_idx = gen_kwargs["stopping_criteria"][0].stop_tokens_idx.reshape(minibatch_size, -1)[:, 0]

        # If a stop token was produced, we need to remove its length from the found index,
        # however there might be a chance that the stop token was not produced and the index
        # returned is the length of the generated sequence
        stop_tokens_idx = torch.where(
            stop_tokens_idx > 0,
            stop_tokens_idx - stop_tokens_ids.shape[-1],
            pred_ids.shape[-1],
        )

        # Convert token ids to text transcription
        pred_text = [
            processor.decode(_pred_ids[inputs["input_ids"].shape[1] : _stop_tokens_idx], skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for _pred_ids, _stop_tokens_idx in zip(pred_ids, stop_tokens_idx)
        ]

        # END TIMING
        runtime = time.time() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `librispeech_asr` for the LibriSpeech ASR dataset, or `common_voice` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `validation` for the dev split, or `test` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=2,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    parser.add_argument(
        "--user_prompt",
        type=str,
        default="Transcribe the audio clip into text.",
        help="User prompt string.",
    )
    parser.set_defaults(streaming=False)
    args = parser.parse_args()

    main(args)
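
As a quick illustration (not part of the commit), here is a minimal sketch of how MultipleTokenBatchStoppingCriteria behaves on hand-built token ids; the ids and the stop sequence below are made up for illustration and assume the class above is in scope:

    import torch

    stop_tokens = torch.tensor([[7, 8]])  # a single stop sequence of length 2
    criteria = MultipleTokenBatchStoppingCriteria(stop_tokens, batch_size=2)

    # Step 1: only the second sequence has just emitted the stop pair, so only its index is recorded.
    step1 = torch.tensor([[1, 2, 3, 4], [1, 2, 7, 8]])
    criteria(step1, scores=None)       # falsy: not every batch entry has stopped yet
    print(criteria.stop_tokens_idx)    # tensor([0, 4])

    # Step 2: the first sequence now ends with the stop pair as well, so the whole batch can halt.
    step2 = torch.tensor([[1, 2, 3, 4, 7, 8], [1, 2, 7, 8, 0, 0]])
    done = criteria(step2, scores=None)
    print(criteria.stop_tokens_idx)    # tensor([6, 4]) -> sequence lengths at which each stop was seen
    print(bool(done))                  # True
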

phi/run_phi4_multimodal.sh

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=("microsoft/Phi-4-multimodal-instruct")
BATCH_SIZE=32
MAX_NEW_TOKENS=512

num_models=${#MODEL_IDs[@]}
default_user_prompt="Transcribe the audio clip into text."

for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="voxpopuli" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS} \
        --user_prompt="${default_user_prompt}"

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="ami" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS} \
        --user_prompt="${default_user_prompt}"

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="earnings22" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS} \
        --user_prompt="Transcribe the audio clip to English text."

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="gigaspeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS} \
        --user_prompt="${default_user_prompt}"

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.clean" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS} \
        --user_prompt="${default_user_prompt}"

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.other" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS} \
        --user_prompt="${default_user_prompt}"

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="spgispeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS} \
        --user_prompt="${default_user_prompt}"

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="tedlium" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1 \
        --max_new_tokens=${MAX_NEW_TOKENS} \
        --user_prompt="${default_user_prompt}"

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done
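
If the runs above finish but the aggregation step needs to be repeated, the final scoring call can also be issued by hand. This is a minimal sketch mirroring the last step of the script, assuming it is run from the repository's normalizer/ directory and that the phi/results folder produced by the runs above exists:

    # Manual re-scoring sketch (same call the shell script makes via `python -c`).
    import eval_utils

    eval_utils.score_results("../phi/results", "microsoft/Phi-4-multimodal-instruct")
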
