Description
Speech workloads appear to be ~5x slower per update step than before.
This happens with both pmap and jit.
Steps to Reproduce
In the container, run:
python submission_runner.py --framework=jax --workload=librispeech_deepspeech --submission_path=reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_target_setting.py --data_dir=/data/librispeech --num_tuning_trials=1 --experiment_dir=/experiment_runs --experiment_name=tests/regression_tests/adamw --overwrite=True --save_checkpoints=False --max_global_steps=10 --librispeech_tokenizer_vocab_path=/data/librispeech/spm_model.vocab --tuning_ruleset=external --tuning_search_space=reference_algorithms/qualification_baselines/external_tuning/tuning_search_space.json
Source or Possible Fix
Maybe a package update is resulting in a compilation difference?
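One quick check (not part of the original report; the version-printing snippet below is only a suggested sketch): record the JAX and jaxlib versions inside the container so a run from before the slowdown can be diffed against the current environment.

import jax
import jaxlib

# Record the exact compiler stack used for this run; comparing these values
# between a fast and a slow container narrows the regression to a package bump.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("backend:", jax.default_backend(), jax.devices())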
Suspicious message in the logs:
2025-07-15 03:06:29.349878: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:
%reduce-window.10 = f32[256,500]{1,0} reduce-window(%broadcast.248, %constant.119), window={size=1x500 pad=0_0x499_0}, to_apply=%region_35.2517.clone, metadata={op_name="jit(_eval_step)/jit(cumsum)/LibriSpeechConformerWorkload.sequence_mask/reduce_window_sum" source_file="/algorithmic-efficiency/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py" source_line=244}
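For context, here is a minimal sketch of the pattern the alarm points at. The mask construction below is a guess at what sequence_mask might do, not the workload's actual code: a cumsum over a [256, 500] one-hot lowers to a reduce_window_sum under jit (matching the op_name in the log), and when its operands are compile-time constants XLA will try to constant-fold the whole window reduction, which is where the > 1s appears to be spent.

import jax
import jax.numpy as jnp

def sequence_mask(lengths, maxlen):
  # Hypothetical cumsum-based mask: one-hot of each sequence's end index,
  # cumulatively summed along time, so positions at or after the end become 1
  # and (1 - cumsum) marks the valid frames.
  one_hot = jax.nn.one_hot(lengths, maxlen, dtype=jnp.float32)
  return 1.0 - jnp.cumsum(one_hot, axis=-1)

@jax.jit
def _eval_step(lengths):
  return sequence_mask(lengths, 500)

print(_eval_step(jnp.full((256,), 320, dtype=jnp.int32)).shape)  # (256, 500)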