Commit 4d5f330
parakeet cuda mode works (pytorch#16674)
When exporting with Dim.AUTO, GPU-enabled PyTorch adds CUDA-specific guards for convolution kernels based on cuDNN workspace calculations, and these guards constrain the dynamic dimension based on the size of the sample tensor. With the original miscalculated 100-frame sample, this created a guard limiting inputs to ~160 mel frames (~1.6 sec of audio), causing runtime failures on longer inputs. On Mac/CPU-only PyTorch builds these cuDNN guards are never added, since no CUDA backend is selected, so NeMo's internal limit of 5000 frames is preserved and inference stays correct. This PR fixes the issue by sizing the sample tensor to the desired maximum audio duration (max_mel_frames), ensuring the cuDNN guard accommodates the full input range. It also adds CI for Parakeet on the CUDA backend.
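For illustration, a minimal sketch of the failure mode described above (TinyEncoder and MAX_MEL_FRAMES are hypothetical stand-ins, not code from this commit): when a dimension is marked Dim.AUTO, torch.export infers its bounds from guards collected while tracing the example input, so the example should be sized to the maximum intended input.

```python
# Illustrative sketch only -- TinyEncoder and MAX_MEL_FRAMES are hypothetical,
# not the actual Parakeet export code.
import torch
from torch.export import Dim, export


class TinyEncoder(torch.nn.Module):
    def __init__(self, feat_in: int = 128):
        super().__init__()
        self.conv = torch.nn.Conv1d(feat_in, 256, kernel_size=3, padding=1)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        return self.conv(mel)


MAX_MEL_FRAMES = 5000  # NeMo's internal encoder limit (~50 s of audio)

# Before the fix the example tensor had only 100 frames, and on CUDA builds
# cuDNN-derived guards pinned the dynamic frame dim near that size (~160
# frames). Sizing the example at the intended maximum lets the inferred
# range cover the whole input span.
example = torch.randn(1, 128, MAX_MEL_FRAMES)
ep = export(
    TinyEncoder(),
    (example,),
    dynamic_shapes={"mel": {2: Dim.AUTO}},
    strict=False,
)
print(ep.range_constraints)  # inspect the bounds inferred for the frame dim
```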
1 parent 59aa69d commit 4d5f330

File tree

5 files changed: +157 -34 lines changed

.ci/scripts/export_model_artifact.sh

Lines changed: 31 additions & 1 deletion
@@ -21,6 +21,7 @@ Arguments:
   - mistralai/Voxtral-Mini-3B-2507
   - openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo})
   - google/gemma-3-4b-it
+  - nvidia/parakeet-tdt

 quant_name Quantization type (optional, default: non-quantized)
 Options:
@@ -34,6 +35,7 @@ Examples:
   export_model_artifact.sh metal "openai/whisper-small"
   export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
   export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output"
+  export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output"
 EOF
 }

@@ -101,9 +103,21 @@ case "$HF_MODEL" in
     PREPROCESSOR_FEATURE_SIZE=""
     PREPROCESSOR_OUTPUT=""
     ;;
+  nvidia/parakeet-tdt)
+    if [ "$DEVICE" = "metal" ]; then
+      echo "Error: Export for device 'metal' is not yet tested for model '$HF_MODEL'"
+      exit 1
+    fi
+    MODEL_NAME="parakeet"
+    TASK=""
+    MAX_SEQ_LEN=""
+    EXTRA_PIP=""
+    PREPROCESSOR_FEATURE_SIZE=""
+    PREPROCESSOR_OUTPUT=""
+    ;;
   *)
     echo "Error: Unsupported model '$HF_MODEL'"
-    echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it"
+    echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, nvidia/parakeet-tdt"
     exit 1
     ;;
 esac
@@ -141,6 +155,22 @@ if [ -n "$EXTRA_PIP" ]; then
 fi
 pip list

+# Parakeet uses a custom export script
+if [ "$MODEL_NAME" = "parakeet" ]; then
+  pip install -r examples/models/parakeet/install_requirements.txt
+
+  python examples/models/parakeet/export_parakeet_tdt.py \
+    --backend "$DEVICE" \
+    --output-dir "${OUTPUT_DIR}"
+
+  test -f "${OUTPUT_DIR}/model.pte"
+  test -f "${OUTPUT_DIR}/aoti_${DEVICE}_blob.ptd"
+  test -f "${OUTPUT_DIR}/tokenizer.model"
+  ls -al "${OUTPUT_DIR}"
+  echo "::endgroup::"
+  exit 0
+fi
+
 MAX_SEQ_LEN_ARG=""
 if [ -n "$MAX_SEQ_LEN" ]; then
   MAX_SEQ_LEN_ARG="--max_seq_len $MAX_SEQ_LEN"

.ci/scripts/test_model_e2e.sh

Lines changed: 41 additions & 14 deletions
@@ -21,6 +21,7 @@ Arguments:
   - mistralai/Voxtral-Mini-3B-2507
   - openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo})
   - google/gemma-3-4b-it
+  - nvidia/parakeet-tdt

 quant_name Quantization type (required)
 Options:
@@ -35,6 +36,7 @@ Arguments:
 Examples:
   test_model_e2e.sh metal "openai/whisper-small" "non-quantized"
   test_model_e2e.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" "./model_output"
+  test_model_e2e.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./model_output"
 EOF
 }

@@ -118,9 +120,21 @@ case "$HF_MODEL" in
     AUDIO_FILE=""
     IMAGE_PATH="docs/source/_static/img/et-logo.png"
     ;;
+  nvidia/parakeet-tdt)
+    MODEL_NAME="parakeet"
+    RUNNER_TARGET="parakeet_runner"
+    RUNNER_PATH="parakeet"
+    EXPECTED_OUTPUT="Phoebe"
+    PREPROCESSOR=""
+    TOKENIZER_URL=""
+    TOKENIZER_FILE="tokenizer.model"
+    AUDIO_URL="https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav"
+    AUDIO_FILE="test_audio.wav"
+    IMAGE_PATH=""
+    ;;
   *)
     echo "Error: Unsupported model '$HF_MODEL'"
-    echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it"
+    echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, nvidia/parakeet-tdt"
     exit 1
     ;;
 esac
@@ -133,13 +147,15 @@ echo "::endgroup::"
 echo "::group::Prepare $MODEL_NAME Artifacts"


-# Download tokenizer files
-if [ "$TOKENIZER_FILE" != "" ]; then
-  curl -L $TOKENIZER_URL/$TOKENIZER_FILE -o $MODEL_DIR/$TOKENIZER_FILE
-else
-  curl -L $TOKENIZER_URL/tokenizer.json -o $MODEL_DIR/tokenizer.json
-  curl -L $TOKENIZER_URL/tokenizer_config.json -o $MODEL_DIR/tokenizer_config.json
-  curl -L $TOKENIZER_URL/special_tokens_map.json -o $MODEL_DIR/special_tokens_map.json
+# Download tokenizer files (skip for parakeet which exports tokenizer with model)
+if [ "$MODEL_NAME" != "parakeet" ]; then
+  if [ "$TOKENIZER_FILE" != "" ]; then
+    curl -L $TOKENIZER_URL/$TOKENIZER_FILE -o $MODEL_DIR/$TOKENIZER_FILE
+  else
+    curl -L $TOKENIZER_URL/tokenizer.json -o $MODEL_DIR/tokenizer.json
+    curl -L $TOKENIZER_URL/tokenizer_config.json -o $MODEL_DIR/tokenizer_config.json
+    curl -L $TOKENIZER_URL/special_tokens_map.json -o $MODEL_DIR/special_tokens_map.json
+  fi
 fi

 # Download test files
@@ -187,23 +203,34 @@ case "$MODEL_NAME" in
   gemma3)
     RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --image_path $IMAGE_PATH"
     ;;
+  parakeet)
+    RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --data_path ${MODEL_DIR}/aoti_${DEVICE}_blob.ptd --audio_path ${MODEL_DIR}/$AUDIO_FILE --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE"
+    ;;
 esac

 OUTPUT=$($RUNNER_BIN $RUNNER_ARGS 2>&1)
 EXIT_CODE=$?
 set -e

-if ! echo "$OUTPUT" | grep -iq "$EXPECTED_OUTPUT"; then
-  echo "Expected output '$EXPECTED_OUTPUT' not found in output"
-  exit 1
-else
-  echo "Success: '$EXPECTED_OUTPUT' found in output"
-fi
+echo "Runner output:"
+echo "$OUTPUT"

 if [ $EXIT_CODE -ne 0 ]; then
   echo "Unexpected exit code: $EXIT_CODE"
   exit $EXIT_CODE
 fi
+
+# Validate output for models that have expected output
+if [ -n "$EXPECTED_OUTPUT" ]; then
+  if ! echo "$OUTPUT" | grep -iq "$EXPECTED_OUTPUT"; then
+    echo "Expected output '$EXPECTED_OUTPUT' not found in output"
+    exit 1
+  else
+    echo "Success: '$EXPECTED_OUTPUT' found in output"
+  fi
+else
+  echo "SUCCESS: Runner completed successfully"
+fi
 echo "::endgroup::"

 popd

.github/workflows/cuda.yml

Lines changed: 31 additions & 6 deletions
@@ -138,6 +138,8 @@ jobs:
           name: "whisper-large-v3-turbo"
         - repo: "google"
           name: "gemma-3-4b-it"
+        - repo: "nvidia"
+          name: "parakeet-tdt"
       quant:
         - "non-quantized"
         - "quantized-int4-tile-packed"
@@ -148,6 +150,15 @@ jobs:
            repo: "google"
            name: "gemma-3-4b-it"
          quant: "quantized-int4-weight-only"
+        # Parakeet only supports non-quantized
+        - model:
+            repo: "nvidia"
+            name: "parakeet-tdt"
+          quant: "quantized-int4-tile-packed"
+        - model:
+            repo: "nvidia"
+            name: "parakeet-tdt"
+          quant: "quantized-int4-weight-only"
     with:
       timeout: 90
       secrets-env: EXECUTORCH_HF_TOKEN
@@ -165,12 +176,15 @@ jobs:
        ./install_executorch.sh
        echo "::endgroup::"

-        echo "::group::Setup Huggingface"
-        pip install -U "huggingface_hub[cli]<1.0" accelerate
-        huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
-        OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
-        pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}
-        echo "::endgroup::"
+        # Setup Huggingface only for models that need it (not parakeet)
+        if [ "${{ matrix.model.name }}" != "parakeet-tdt" ]; then
+          echo "::group::Setup Huggingface"
+          pip install -U "huggingface_hub[cli]<1.0" accelerate
+          huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
+          OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
+          pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}
+          echo "::endgroup::"
+        fi

        source .ci/scripts/export_model_artifact.sh cuda "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}"

@@ -193,6 +207,8 @@ jobs:
           name: "whisper-large-v3-turbo"
         - repo: "google"
           name: "gemma-3-4b-it"
+        - repo: "nvidia"
+          name: "parakeet-tdt"
       quant:
         - "non-quantized"
         - "quantized-int4-tile-packed"
@@ -203,6 +219,15 @@ jobs:
            repo: "google"
            name: "gemma-3-4b-it"
          quant: "quantized-int4-weight-only"
+        # Parakeet only supports non-quantized
+        - model:
+            repo: "nvidia"
+            name: "parakeet-tdt"
+          quant: "quantized-int4-tile-packed"
+        - model:
+            repo: "nvidia"
+            name: "parakeet-tdt"
+          quant: "quantized-int4-weight-only"
     with:
       timeout: 90
       runner: linux.g5.4xlarge.nvidia.gpu
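The exclude entries above rely on standard GitHub Actions matrix semantics: the full model × quant cross product is generated first, then the listed combinations are dropped, so parakeet-tdt only ever runs non-quantized. A conceptual sketch of that expansion (illustrative values, not Actions internals):

```python
# Conceptual sketch of GitHub Actions matrix expansion; model/quant values
# here are a subset for illustration.
from itertools import product

models = ["gemma-3-4b-it", "parakeet-tdt"]
quants = ["non-quantized", "quantized-int4-tile-packed", "quantized-int4-weight-only"]
excludes = {
    ("gemma-3-4b-it", "quantized-int4-weight-only"),
    ("parakeet-tdt", "quantized-int4-tile-packed"),
    ("parakeet-tdt", "quantized-int4-weight-only"),
}

# Cross product minus excludes = the jobs that actually run.
jobs = [(m, q) for m, q in product(models, quants) if (m, q) not in excludes]
print(jobs)  # parakeet-tdt survives only with "non-quantized"
```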

examples/models/parakeet/README.md

Lines changed: 22 additions & 4 deletions
@@ -38,10 +38,21 @@ python export_parakeet_tdt.py --backend metal --output-dir ./parakeet_metal
 ```

 This generates:
-- `parakeet_tdt.pte` - The compiled model
+- `model.pte` - The compiled Parakeet TDT model
 - `aoti_metal_blob.ptd` - Metal kernel blob required at runtime
 - `tokenizer.model` - SentencePiece tokenizer

+### CUDA Export (Linux)
+
+```bash
+python export_parakeet_tdt.py --backend cuda --output-dir ./parakeet_cuda
+```
+
+This generates:
+- `model.pte` - The compiled Parakeet TDT model
+- `aoti_cuda_blob.ptd` - CUDA kernel blob required at runtime
+- `tokenizer.model` - SentencePiece tokenizer
+
 ## C++ Runner

 ### Building
@@ -55,7 +66,7 @@ make parakeet-cpu
 # Metal build (macOS)
 make parakeet-metal

-# CUDA build (Linux/Windows)
+# CUDA build (Linux)
 make parakeet-cuda
 ```

@@ -66,16 +77,23 @@ From the executorch root directory:
 ```bash
 # CPU/XNNPACK
 ./cmake-out/examples/models/parakeet/parakeet_runner \
-  --model_path examples/models/parakeet/parakeet_tdt_exports/parakeet_tdt.pte \
+  --model_path examples/models/parakeet/parakeet_tdt_exports/model.pte \
   --audio_path /path/to/audio.wav \
   --tokenizer_path examples/models/parakeet/parakeet_tdt_exports/tokenizer.model

 # Metal (include .ptd data file)
 DYLD_LIBRARY_PATH=/usr/lib ./cmake-out/examples/models/parakeet/parakeet_runner \
-  --model_path examples/models/parakeet/parakeet_metal/parakeet_tdt.pte \
+  --model_path examples/models/parakeet/parakeet_metal/model.pte \
   --data_path examples/models/parakeet/parakeet_metal/aoti_metal_blob.ptd \
   --audio_path /path/to/audio.wav \
   --tokenizer_path examples/models/parakeet/parakeet_metal/tokenizer.model
+
+# CUDA (include .ptd data file)
+./cmake-out/examples/models/parakeet/parakeet_runner \
+  --model_path examples/models/parakeet/parakeet_cuda/model.pte \
+  --data_path examples/models/parakeet/parakeet_cuda/aoti_cuda_blob.ptd \
+  --audio_path /path/to/audio.wav \
+  --tokenizer_path examples/models/parakeet/parakeet_cuda/tokenizer.model
 ```

 ### Runner Arguments

examples/models/parakeet/export_parakeet_tdt.py

Lines changed: 32 additions & 9 deletions
@@ -7,7 +7,6 @@
 import tempfile

 import torch
-
 import torchaudio
 from executorch.exir import (
     EdgeCompileConfig,
@@ -297,37 +296,61 @@ def forward(


 def export_all(model):
+    """Export all model components.
+
+    The maximum audio duration is determined by the model's internal
+    max_audio_length (~50 seconds for Parakeet with max_audio_length=5000).
+    """
     programs = {}

+    # Get audio parameters from model config
+    sample_rate = model.preprocessor._cfg.sample_rate
+    window_stride = float(model.preprocessor._cfg.window_stride)
+
+    # Get encoder's actual limit from NeMo model
+    encoder_max_frames = model.encoder.max_audio_length  # typically 5000
+    max_audio_sec = int(encoder_max_frames * window_stride)
+
+    max_audio_samples = int(sample_rate * max_audio_sec)
+    max_mel_frames = int(max_audio_sec / window_stride)
+
     preprocessor_wrapper = PreprocessorWrapper(model.preprocessor)
     preprocessor_wrapper.eval()
-    sample_audio = torch.randn(16000 * 10)
+    sample_audio = torch.randn(max_audio_samples)
     sample_length = torch.tensor([sample_audio.shape[0]], dtype=torch.int64)
-    # The preprocessor definition changes if cuda is available (likely due to making it cuda graphable).
-    # Unfortunately that new definition is not supported by export, so we need to stop that from happening.
+    # The preprocessor uses different code paths when CUDA is available, which include
+    # data-dependent conditionals that torch.export cannot handle. Force CPU path.
     old_cuda_is_available = torch.cuda.is_available
     torch.cuda.is_available = lambda: False
     programs["preprocessor"] = export(
         preprocessor_wrapper,
         (sample_audio, sample_length),
         dynamic_shapes={
-            "audio": {0: Dim("audio_len", min=1600, max=16000 * 600)},
+            # min=1600 samples = 0.1 sec @ 16kHz, max aligned with encoder limit
+            "audio": {0: Dim("audio_len", min=1600, max=max_audio_samples)},
             "length": {},
         },
         strict=False,
     )
     torch.cuda.is_available = old_cuda_is_available

     feat_in = getattr(model.encoder, "_feat_in", 128)
-    audio_signal = torch.randn(1, feat_in, 100)
-    length = torch.tensor([100], dtype=torch.int64)
+    # Use max_mel_frames as example to ensure Dim.AUTO infers the full range.
+    # Smaller examples cause Dim.AUTO to infer narrow bounds.
+    audio_signal = torch.randn(1, feat_in, max_mel_frames)
+    length = torch.tensor([max_mel_frames], dtype=torch.int64)
     encoder_with_proj = EncoderWithProjection(model.encoder, model.joint)
     encoder_with_proj.eval()
+
     programs["encoder"] = export(
         encoder_with_proj,
         (),
         kwargs={"audio_signal": audio_signal, "length": length},
-        dynamic_shapes={"audio_signal": {2: Dim.AUTO}, "length": {}},
+        dynamic_shapes={
+            # Use Dim.AUTO - explicit bounds fail due to different size guards on different devices
+            "audio_signal": {2: Dim.AUTO},
+            "length": {},
+        },
         strict=False,
     )

@@ -553,7 +576,7 @@ def main():

     et = lower_to_executorch(programs, metadata=metadata, backend=args.backend)

-    pte_path = os.path.join(args.output_dir, "parakeet_tdt.pte")
+    pte_path = os.path.join(args.output_dir, "model.pte")
     print(f"\nSaving ExecuTorch program to: {pte_path}")
     with open(pte_path, "wb") as f:
         et.write_to_file(f)
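To make the new size derivation in `export_all` concrete, here is the arithmetic with typical Parakeet TDT values (sample_rate and window_stride come from the NeMo preprocessor config; the exact numbers here are illustrative):

```python
# Worked example of the size derivation in export_all, using typical
# Parakeet TDT values.
sample_rate = 16000        # Hz, from model.preprocessor._cfg.sample_rate
window_stride = 0.01       # seconds per mel frame (10 ms hop)
encoder_max_frames = 5000  # model.encoder.max_audio_length

max_audio_sec = int(encoder_max_frames * window_stride)  # 5000 * 0.01 = 50 s
max_audio_samples = int(sample_rate * max_audio_sec)     # 16000 * 50 = 800,000
max_mel_frames = int(max_audio_sec / window_stride)      # 50 / 0.01 = 5000

print(max_audio_sec, max_audio_samples, max_mel_frames)  # 50 800000 5000
```

With these values the preprocessor's example input is 800,000 samples and the encoder's example input is 5000 mel frames, so the guards inferred during export cover the encoder's full 5000-frame range instead of the ~160-frame bound produced by the old 100-frame sample.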
