
Feature: add CLI option for torch's built-in matmul precision on supported graphics cards #413

Open

AeneasTews wants to merge 2 commits into jwohlwend:main from AeneasTews:main

Conversation

@AeneasTews

When using a supported card with a preview build of PyTorch (currently tested on version 2.8.0.dev20250616+cu128 with an NVIDIA 5070 Ti), PyTorch reports that the float32 matmul precision setting is available; using high or medium instead of the highest setting drastically improves runtimes, by up to 100%. This commit adds a command-line option to toggle this based on user preference; the default is highest. Keeping the default at highest should not cause any compatibility issues, as it matches the current behavior.

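For illustration, here is a minimal sketch of how such an option could be wired up, assuming a Click-based CLI; the option name and surrounding command are placeholders, not the exact code in this PR:

```python
import click
import torch

@click.command()
@click.option(
    "--matmul-precision",
    type=click.Choice(["highest", "high", "medium"]),
    default="highest",  # matches PyTorch's current default, so no behavior change
    help="float32 matmul precision; high/medium enable TF32 on supported GPUs.",
)
def predict(matmul_precision: str) -> None:
    # Must run before any float32 matmuls are dispatched.
    torch.set_float32_matmul_precision(matmul_precision)
    ...  # rest of the prediction pipeline
```
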
@xavierholt

Here's another vote for this! I almost made the same PR, but then I saw that someone beat me to it... I've edited this manually in the past, and I've seen a significant speedup going from highest to high (and no appreciable difference between high and medium). This was for Boltz1 on A100s, if I recall correctly.

It might also be worth checking the warnings filter as part of this. There's a call to filterwarnings() just above the call to set_float32_matmul_precision() that looks like it should hide the warning below, but it's still showing up.

You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
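One plausible explanation, sketched below under the assumption that the message goes through Python's warnings module: filterwarnings() matches its message pattern as a regex against the *start* of the warning text, so a pattern written for a substring in the middle never fires unless it begins with ".*". The filter also has to be installed before the warning is first emitted.

```python
import warnings

# Sketch only: the message argument is a regex matched from the beginning of
# the warning text. message="Tensor Cores" would never match; a leading ".*"
# lets the pattern match mid-string.
warnings.filterwarnings(
    "ignore",
    message=r".*Tensor Cores.*",
)
```
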

@jwohlwend
Owner

We've found in the past that using high or medium can hurt performance, so I'm not super eager to incentivize users to do this. It's not just a question of card compatibility.

@xavierholt

@jwohlwend I don't see a problem if it's done in the way this PR does it: keep the default at highest, but add a command line option so that people who know what they're doing (which hopefully translates to "people who read the documentation and run benchmarks") can get results cheaper and faster, if their systems support it. Unless there's an accuracy problem when running with lower precision that I'm unaware of?

The main thing that would incentivize people to drop the accuracy level is the big "you're not fully using your GPU" warning message that PyTorch prints out, but that's not touched by this PR. As I mentioned above, it seems like there's code to suppress that message, but I still see it, even with the latest release.

@jwohlwend
Owner

No, that's my point: we've observed accuracy issues with TF32.
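
To make the trade-off concrete, here is a small hypothetical repro (not from the Boltz test suite) of the precision gap; it assumes an Ampere-or-newer NVIDIA GPU, where high enables TF32, which keeps roughly 10 mantissa bits versus float32's 23:

```python
import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

torch.set_float32_matmul_precision("highest")  # full float32 matmul
ref = a @ b

torch.set_float32_matmul_precision("high")  # allow TF32 (~10 mantissa bits)
approx = a @ b

# Elementwise error is typically on the order of 1e-3 to 1e-2 at this scale;
# whether that matters depends on how errors accumulate through the model.
print((ref - approx).abs().max().item())
```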

@xavierholt

Oh! I thought you meant "performance" in the time/efficiency sense. If there are accuracy problems then I agree that this is a lot more dubious.

@AeneasTews
Author

@jwohlwend Thank you very much for your responses. Would you be able to share the tests you performed to determine the accuracy deterioration at the different precision levels? Thanks for your help, and best regards!

