
Conversation

@cbensimon (Contributor) commented on Sep 26, 2025

I recently noticed that we are spending a non-negligible amount of time in `version.parse` when running pipelines (approx. ~50 ms per step for the QwenImageEdit pipeline on a ZeroGPU Space, for instance, which in this case represents almost 10% of the actual compute). The calls to those version checks originate from:

- https://github.com/huggingface/diffusers/blob/4588bbeb4229fd307119257e273a424b370573b1/src/diffusers/hooks/hooks.py#L277

Maybe the issue can instead be solved at the root (why do we need to unwrap the modules at each call?), or maybe my particular setup triggered this? (I patched the forward method at the block level, but I don't think that affects `_set_context`.)

py-spy top results (QwenImageEdit H200 half, 28 steps):

  %Own   %Total  OwnTime  TotalTime  Function (filename)                                                                                   
 18.00%  18.00%    8.76s     8.76s   silu (/usr/local/lib/python3.10/site-packages/torch/nn/functional.py)
 13.00%  13.00%    6.09s     6.13s   forward (/usr/local/lib/python3.10/site-packages/torch/nn/modules/linear.py)
  6.00%   6.00%    3.62s     3.82s   forward (/usr/local/lib/python3.10/site-packages/diffusers/models/normalization.py)
 11.00%  20.00%    3.49s     5.35s   __init__ (/usr/local/lib/python3.10/site-packages/packaging/version.py)
  1.00%   1.00%    1.67s     1.67s   apply_rotary_emb_qwen (/usr/local/lib/python3.10/site-packages/diffusers/models/transformers/transform
  3.00%   3.00%    1.06s     1.06s   _modulate (/usr/local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_qwenimage
  3.00%   3.00%   0.980s     1.01s   _cmpkey (/usr/local/lib/python3.10/site-packages/packaging/version.py)
  3.00%   3.00%   0.720s    0.720s   __getattr__ (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py)
  2.00%   2.00%   0.660s    0.660s   named_modules (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py)
  3.00%  21.00%   0.550s    11.19s   __call__ (/usr/local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_qwenimage.
  1.00%  12.00%   0.500s     7.59s   _set_context (/usr/local/lib/python3.10/site-packages/diffusers/hooks/hooks.py)
  5.00%   5.00%   0.500s    0.500s   <genexpr> (/usr/local/lib/python3.10/site-packages/packaging/version.py)
  0.00%  45.00%   0.360s    23.44s   forward (/usr/local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_qwenimage.p
 18.00%  18.00%   0.280s    0.280s   _conv_forward (/usr/local/lib/python3.10/site-packages/torch/nn/modules/conv.py)
  1.00%  66.00%   0.240s    24.06s   _call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py)
  0.00%   0.00%   0.210s    0.220s   layer_norm (/usr/local/lib/python3.10/site-packages/torch/nn/functional.py)
  2.00%  11.00%   0.200s     2.89s   compare_versions (/usr/local/lib/python3.10/site-packages/diffusers/utils/import_utils.py)
  0.00%   0.00%   0.180s    0.180s   <lambda> (<string>)
  0.00%  22.00%   0.170s     5.99s   is_compiled_module (/usr/local/lib/python3.10/site-packages/diffusers/utils/torch_utils.py)
  0.00%  20.00%   0.160s     5.51s   parse (/usr/local/lib/python3.10/site-packages/packaging/version.py)
  0.00%  22.00%   0.150s     5.91s   is_torch_version (/usr/local/lib/python3.10/site-packages/diffusers/utils/import_utils.py)
  0.00%  66.00%   0.120s    24.07s   _wrapped_call_impl (/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py)
  1.00%   1.00%   0.110s    0.110s   _parse_letter_version (/usr/local/lib/python3.10/site-packages/packaging/version.py)
  0.00%   0.00%   0.100s    0.100s   _native_attention (/usr/local/lib/python3.10/site-packages/diffusers/models/attention_dispatch.py)
  0.00%   0.00%   0.100s    0.230s   forward (/usr/local/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
  0.00%   0.00%   0.080s    0.080s   _extract_masked_hidden (/usr/local/lib/python3.10/site-packages/diffusers/pipelines/qwenimage/pipeline
  0.00%   0.00%   0.060s    0.060s   _parse_local_version (/usr/local/lib/python3.10/site-packages/packaging/version.py)
  0.00%   0.00%   0.060s    0.060s   dropout (/usr/local/lib/python3.10/site-packages/torch/nn/functional.py)
  0.00%   0.00%   0.050s    0.470s   forward (/usr/local/lib/python3.10/site-packages/diffusers/models/attention.py)
  0.00%   0.00%   0.040s    0.040s   get_placeholder_mask (/usr/local/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_
  0.00%  21.00%   0.040s    11.32s   forward (/usr/local/lib/python3.10/site-packages/diffusers/models/attention_processor.py)
  0.00%   0.00%   0.040s    0.040s   __lt__ (/usr/local/lib/python3.10/site-packages/packaging/version.py)
  0.00%   0.00%   0.030s    0.030s   unflatten (/usr/local/lib/python3.10/site-packages/torch/_tensor.py)
  0.00%   0.00%   0.030s    0.240s   dispatch_attention_fn (/usr/local/lib/python3.10/site-packages/diffusers/models/attention_dispatch.py)
  0.00%   0.00%   0.030s    0.030s   norm (/usr/local/lib/python3.10/site-packages/torch/functional.py)
  0.00%  22.00%   0.030s     6.02s   unwrap_module (/usr/local/lib/python3.10/site-packages/diffusers/utils/torch_utils.py)
  1.00%   1.00%   0.030s    0.060s   normalize (/usr/local/lib/python3.10/site-packages/torch/nn/functional.py)
  3.00% 100.00%   0.030s    32.00s   __call__ (/usr/local/lib/python3.10/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edi
  0.00%  12.00%   0.030s     5.38s   forward (/usr/local/lib/python3.10/site-packages/peft/tuners/lora/layer.py)

UPDATE:

  • Stack trace of the version check:
Thread 95468 (active+gil): "ThreadPoolExecutor-7_0"
    __init__ (/usr/local/lib/python3.10/site-packages/packaging/version.py:205)
    parse (/usr/local/lib/python3.10/site-packages/packaging/version.py:56)
    is_torch_version (/usr/local/lib/python3.10/site-packages/diffusers/utils/import_utils.py:686)
    is_compiled_module (/usr/local/lib/python3.10/site-packages/diffusers/utils/torch_utils.py:198)
    unwrap_module (/usr/local/lib/python3.10/site-packages/diffusers/utils/torch_utils.py:205)
    _set_context (/usr/local/lib/python3.10/site-packages/diffusers/hooks/hooks.py:277)
    cache_context (/usr/local/lib/python3.10/site-packages/diffusers/models/cache_utils.py:126)
    __enter__ (/usr/local/lib/python3.10/contextlib.py:135)
    __call__ (/usr/local/lib/python3.10/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py:816)
    decorate_context (/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py:120)
    infer (/home/user/app/app.py:224)
    run_task (/usr/local/lib/python3.10/site-packages/spaces/zero/wrappers.py:256)
    run (/usr/local/lib/python3.10/concurrent/futures/thread.py:58)
    _worker (/usr/local/lib/python3.10/concurrent/futures/thread.py:83)
    run (/usr/local/lib/python3.10/threading.py:953)
    _bootstrap_inner (/usr/local/lib/python3.10/threading.py:1016)
    _bootstrap (/usr/local/lib/python3.10/threading.py:973)
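
To make the chain above concrete: `unwrap_module` calls `is_compiled_module`, which calls `is_torch_version`, and the latter re-parses version strings on every invocation, which is why `packaging/version.py` shows up so prominently per step. Below is a minimal sketch of the caching direction this PR is about; the Dynamo warning in the compile tests further down suggests the merged change relies on `functools.lru_cache`, but the helper name and operator table here are illustrative, not the exact diffusers code.

```python
import operator
from functools import lru_cache

import torch
from packaging import version

# Illustrative operator table (diffusers keeps its own mapping internally).
_STR_TO_OP = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}


@lru_cache(maxsize=None)
def cached_is_torch_version(operation: str, required: str) -> bool:
    """Memoized torch version check: version.parse runs once per distinct
    (operation, required) pair per process, not on every pipeline step."""
    current = version.parse(version.parse(torch.__version__).base_version)
    return _STR_TO_OP[operation](current, version.parse(required))
```

With something like this in place, the `__init__` / `_cmpkey` entries from `packaging/version.py` in the profile above become a one-time cost instead of a per-step one.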

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member)

Interesting PR. Do we know why something like this would show up?

0.00%  12.00%   0.030s     5.38s   forward (/usr/local/lib/python3.10/site-packages/peft/tuners/lora/layer.py)

@cbensimon (Contributor, Author) commented on Sep 26, 2025

> Interesting PR. Do we know why something like this would show up?
>
> 0.00%  12.00%   0.030s     5.38s   forward (/usr/local/lib/python3.10/site-packages/peft/tuners/lora/layer.py)

Indeed, I forgot to mention that the model had LoRA weights loaded. I'll check without them and push the results, plus a minimal code snippet to reproduce.

@cbensimon (Contributor, Author)

Small update:

After a quick test on LambdaLabs, it looks like even though a lot of time is spent in version checking (18% on QwenImageEdit, LambdaLabs H100), it does not matter in the end (same end-to-end pipeline duration), probably because CUDA calls happen asynchronously with respect to the Python code.

On ZeroGPU there was a real difference, though (17% faster when disabling the `unwrap_module` function). I'll investigate more in the coming days / weeks.
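
For anyone wanting to check which regime they are in, a quick (generic, not diffusers-specific) way is to time a step with and without an explicit CUDA synchronization:

```python
import time

import torch


def time_step(fn, *args, sync: bool = True, **kwargs):
    """Time one call to fn. With sync=True the GPU queue is drained before and
    after, so the measurement includes actual device time; with sync=False it
    mostly captures Python-side launch overhead, because CUDA kernels are
    enqueued asynchronously."""
    if sync and torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync and torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start
```

When the synchronized and unsynchronized timings are close, the Python side is the bottleneck and removing per-step version checks helps; when they differ a lot, the GPU dominates and the checks are hidden behind kernel execution.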

@cbensimon (Contributor, Author)

After some measurements, I came to the conclusion that performance gains are hard to predict because they depend on whether PyTorch's Python-side overhead is a limiting factor. In fast-GPU + poor-CPU environments (slow CPU, or a CPU stressed by other apps), speed-ups will be visible; otherwise there won't be any.

That said, I feel like caching the version checks is still incrementally better and makes diffusers more CPU-friendly. @sayakpaul do you see any potential unwanted drawbacks?

@sayakpaul (Member)

I don't! We would still just need to see whether it interferes with torch.compile. So, would you be able to run the compilation tests, at least for Flux?

class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):

@cbensimon (Contributor, Author)

All good!

Test command + output:

(.venv) ubuntu@192-222-52-227:~/diffusers$ RUN_SLOW=true RUN_COMPILE=true python -m pytest -v tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests
================================================================================ test session starts ================================================================================
platform linux -- Python 3.10.12, pytest-8.4.2, pluggy-1.6.0 -- /home/ubuntu/diffusers/.venv/bin/python
cachedir: .pytest_cache
rootdir: /home/ubuntu/diffusers
configfile: pyproject.toml
plugins: timeout-2.4.0, xdist-3.8.0, requests-mock-1.10.0, anyio-4.11.0
collected 5 items                                                                                                                                                                   

tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_compile_on_different_shapes PASSED                                               [ 20%]
tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_compile_with_group_offloading PASSED                                             [ 40%]
tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_compile_works_with_aot PASSED                                                    [ 60%]
tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_torch_compile_recompilation_and_graph_break PASSED                               [ 80%]
tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_torch_compile_repeated_blocks PASSED                                             [100%]

================================================================================= warnings summary ==================================================================================
tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_compile_on_different_shapes
tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_compile_with_group_offloading
tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_torch_compile_recompilation_and_graph_break
tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_torch_compile_repeated_blocks
  /home/ubuntu/diffusers/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:1575: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
    torch._dynamo.utils.warn_once(msg)

tests/models/transformers/test_models_transformer_flux.py::FluxTransformerCompileTests::test_compile_on_different_shapes
  /home/ubuntu/diffusers/.venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:282: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================== 5 passed, 5 warnings in 38.44s ===========================================================================

@sayakpaul requested a review from DN6 on October 6, 2025
@cbensimon merged commit cf4b97b into main on Oct 6, 2025 (13 of 14 checks passed)
@vladmandic (Contributor)

Would there be benefits to expanding this to cover the various `is_available` functions? Most of them are constant-based, but they still go through several modules to unwind.
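
The same memoization pattern could plausibly be applied there; as a sketch, assuming such checks only need to run once per process (the helper name below is hypothetical, not an existing diffusers API):

```python
import importlib.util
from functools import lru_cache


@lru_cache(maxsize=None)
def cached_is_package_available(name: str) -> bool:
    """Hypothetical cached availability check: the find_spec lookup (and any
    module traversal behind it) runs only once per package name per process."""
    return importlib.util.find_spec(name) is not None
```

For example, `cached_is_package_available("torch")` pays the import-machinery cost only on the first call and returns from the cache afterwards.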
