Skip to content

Commit 0b4d0bc

Browse files
committed
init
1 parent 692825d commit 0b4d0bc

File tree

5 files changed

+36
-11
lines changed

5 files changed

+36
-11
lines changed

.ci/scripts/export_model_cuda_artifact.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Arguments:
1818
Supported models:
1919
- mistralai/Voxtral-Mini-3B-2507
2020
- openai/whisper-small
21+
- openai/whisper-large-v2
2122
- google/gemma-3-4b-it
2223
2324
quant_name Quantization type (optional, default: non-quantized)
@@ -62,7 +63,7 @@ case "$HF_MODEL" in
6263
PREPROCESSOR_FEATURE_SIZE="128"
6364
PREPROCESSOR_OUTPUT="voxtral_preprocessor.pte"
6465
;;
65-
openai/whisper-small)
66+
openai/whisper-*)
6667
MODEL_NAME="whisper"
6768
TASK="automatic-speech-recognition"
6869
MAX_SEQ_LEN=""
@@ -80,7 +81,7 @@ case "$HF_MODEL" in
8081
;;
8182
*)
8283
echo "Error: Unsupported model '$HF_MODEL'"
83-
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-small, google/gemma-3-4b-it"
84+
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-small, openai/whisper-large-v2, google/gemma-3-4b-it"
8485
exit 1
8586
;;
8687
esac

.ci/scripts/test_model_cuda_e2e.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Arguments:
1818
Supported models:
1919
- mistralai/Voxtral-Mini-3B-2507
2020
- openai/whisper-small
21+
- openai/whisper-large-v2
2122
- google/gemma-3-4b-it
2223
2324
quant_name Quantization type (required)
@@ -91,13 +92,13 @@ case "$HF_MODEL" in
9192
AUDIO_FILE="poem.wav"
9293
IMAGE_PATH=""
9394
;;
94-
openai/whisper-small)
95-
MODEL_NAME="whisper"
95+
openai/whisper-*)
96+
MODEL_NAME="${HF_MODEL#openai/}"
9697
RUNNER_TARGET="whisper_runner"
9798
RUNNER_PATH="whisper"
9899
EXPECTED_OUTPUT="Mr. Quilter is the apostle of the middle classes"
99100
PREPROCESSOR="whisper_preprocessor.pte"
100-
TOKENIZER_URL="https://huggingface.co/openai/whisper-small/resolve/main" # @lint-ignore
101+
TOKENIZER_URL="https://huggingface.co/${HF_MODEL}/resolve/main" # @lint-ignore
101102
TOKENIZER_FILE=""
102103
AUDIO_URL=""
103104
AUDIO_FILE="output.wav"
@@ -117,7 +118,7 @@ case "$HF_MODEL" in
117118
;;
118119
*)
119120
echo "Error: Unsupported model '$HF_MODEL'"
120-
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-small, google/gemma-3-4b-it"
121+
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-small, openai/whisper-large-v2, google/gemma-3-4b-it"
121122
exit 1
122123
;;
123124
esac
@@ -142,7 +143,7 @@ fi
142143
# Download test files
143144
if [ "$AUDIO_URL" != "" ]; then
144145
curl -L $AUDIO_URL -o ${MODEL_DIR}/$AUDIO_FILE
145-
elif [ "$MODEL_NAME" = "whisper" ]; then
146+
elif [[ "$MODEL_NAME" == *whisper* ]]; then
146147
conda install -y -c conda-forge "ffmpeg<8"
147148
pip install datasets soundfile torchcodec
148149
python -c "from datasets import load_dataset;import soundfile as sf;sample = load_dataset('distil-whisper/librispeech_long', 'clean', split='validation')[0]['audio'];sf.write('${MODEL_DIR}/$AUDIO_FILE', sample['array'][:sample['sampling_rate']*30], sample['sampling_rate'])"
@@ -180,7 +181,7 @@ case "$MODEL_NAME" in
180181
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR"
181182
;;
182183
whisper)
183-
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR"
184+
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR --model_name ${WHISPER_MODEL_NAME}"
184185
;;
185186
gemma3)
186187
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --image_path $IMAGE_PATH"

.github/workflows/cuda.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ jobs:
104104
name: "Voxtral-Mini-3B-2507"
105105
- repo: "openai"
106106
name: "whisper-small"
107+
- repo: "openai"
108+
name: "whisper-large-v2"
107109
- repo: "google"
108110
name: "gemma-3-4b-it"
109111
quant:
@@ -223,6 +225,8 @@ jobs:
223225
name: "Voxtral-Mini-3B-2507"
224226
- repo: "openai"
225227
name: "whisper-small"
228+
- repo: "openai"
229+
name: "whisper-large-v2"
226230
- repo: "google"
227231
name: "gemma-3-4b-it"
228232
quant:

examples/models/whisper/main.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ DEFINE_string(
3939
audio_path,
4040
"",
4141
"Path to input audio file. Accepts .wav or raw float .bin.");
42+
DEFINE_string(
43+
model_name,
44+
"base",
45+
"Whisper model name (base, small, medium, large, large-v2, large-v3, turbo).");
4246
DEFINE_double(
4347
temperature,
4448
0.0,
@@ -109,7 +113,22 @@ int main(int argc, char** argv) {
109113
executorch::extension::asr::AsrTranscribeConfig config;
110114
config.max_new_tokens = FLAGS_max_new_tokens;
111115
config.temperature = static_cast<float>(FLAGS_temperature);
112-
config.decoder_start_token_id = 50257;
116+
117+
// Set decoder_start_token_id based on model version
118+
if (FLAGS_model_name == "large-v2" || FLAGS_model_name == "large-v3" ||
119+
FLAGS_model_name == "turbo") {
120+
config.decoder_start_token_id = 50258;
121+
ET_LOG(
122+
Info,
123+
"Using decoder_start_token_id=50258 for model: %s",
124+
FLAGS_model_name.c_str());
125+
} else {
126+
config.decoder_start_token_id = 50257;
127+
ET_LOG(
128+
Info,
129+
"Using decoder_start_token_id=50257 for model: %s",
130+
FLAGS_model_name.c_str());
131+
}
113132

114133
auto result =
115134
runner.transcribe(features, config, [&](const std::string& piece) {

extension/asr/runner/runner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
193193
"Conversion complete, first value = %f",
194194
static_cast<float>(
195195
preprocessed_features
196-
->mutable_data_ptr<::executorch::aten::BFloat16>()[0]));
196+
->mutable_data_ptr<float>()[0]));
197197
}
198198
}
199199

@@ -225,7 +225,7 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
225225
"Encoder first value: %f",
226226
static_cast<float>(
227227
encoder_output_tensor
228-
.mutable_data_ptr<::executorch::aten::BFloat16>()[0]));
228+
.mutable_data_ptr<float>()[0]));
229229

230230
auto encoder_output_ptr = std::make_shared<::executorch::aten::Tensor>(
231231
std::move(encoder_output_tensor));

0 commit comments

Comments
 (0)