File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed
transformer_engine/jax/triton_extensions Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change 3333import hashlib
3434import os
3535import warnings
36+ from packaging import version
3637from typing import Any , Callable , Mapping
3738import zlib
3839
@@ -274,13 +275,16 @@ def compile_triton(
274275 return _TRITON_KERNEL_CACHE [cache_key ]
275276
276277 # Compile kernel
278+ cuda_option_kwargs = {}
279+ if version .parse (triton .__version__ ) <= version .parse ("3.6.0" ):
280+ cuda_option_kwargs ["cluster_dims" ] = (1 , 1 , 1 )
277281 options = cb .CUDAOptions (
278282 num_warps = num_warps ,
279283 num_stages = num_stages ,
280284 num_ctas = num_ctas ,
281- cluster_dims = (1 , 1 , 1 ),
282285 debug = False ,
283286 enable_fp_fusion = enable_fp_fusion ,
287+ ** cuda_option_kwargs
284288 )
285289
286290 # Mark constants as constexpr in signature
You can’t perform that action at this time.
0 commit comments