Skip to content

Commit fa99a7d

Browse files
Fix cb.CUDAOptions usage for Triton 3.6.0
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
1 parent fbb16f4 commit fa99a7d

File tree

1 file changed

+5
-1
lines changed
  • transformer_engine/jax/triton_extensions

1 file changed

+5
-1
lines changed

transformer_engine/jax/triton_extensions/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import hashlib
3434
import os
3535
import warnings
36+
from packaging import version
3637
from typing import Any, Callable, Mapping
3738
import zlib
3839

@@ -274,13 +275,16 @@ def compile_triton(
274275
return _TRITON_KERNEL_CACHE[cache_key]
275276

276277
# Compile kernel
278+
cuda_option_kwargs = {}
279+
if version.parse(triton.__version__) <= version.parse("3.6.0"):
280+
cuda_option_kwargs["cluster_dims"] = (1, 1, 1)
277281
options = cb.CUDAOptions(
278282
num_warps=num_warps,
279283
num_stages=num_stages,
280284
num_ctas=num_ctas,
281-
cluster_dims=(1, 1, 1),
282285
debug=False,
283286
enable_fp_fusion=enable_fp_fusion,
287+
**cuda_option_kwargs
284288
)
285289

286290
# Mark constants as constexpr in signature

0 commit comments

Comments
 (0)