@@ -646,15 +646,27 @@ def _get_exceeding_shared_memory_checker(
     If the device does not report available shared memory, returns None.
     """
 
+    from ..utils import get_gpu_shared_memory
+
+    sm_available = None
+
     try:
         device = torch.cuda.current_device()
         props = torch.cuda.get_device_properties(device)
         if not hasattr(props, "shared_memory_per_block_optin"):  # for NVidia GPUs
             return None
         sm_available = int(props.shared_memory_per_block_optin)
     except Exception:
-        # If CUDA is not available or properties cannot be queried, return None
-        return None
+        pass
+
+    # ROCm specific logic to get shared memory
+    if torch.version.hip and sm_available is None:
+        try:
+            sm_available = get_gpu_shared_memory()
+            if sm_available == 0:
+                return None
+        except Exception:
+            return None
 
     # TODO make a BaseDeviceConfigHeuristics to handle different device configuration in its own implementation.
     def exceeds(gemm_config: BaseConfig, dtype_size: int) -> bool:
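
The reordered control flow above is worth spelling out: the CUDA property query no longer returns None on failure; it leaves sm_available unset so the ROCm branch can try get_gpu_shared_memory() instead. Below is a minimal, self-contained sketch of that fallback pattern. The stub for get_gpu_shared_memory is hypothetical; the real helper lives in the module's ..utils, and its contract is only implied by the diff (0 is treated as "no answer").

from typing import Optional

import torch


def get_gpu_shared_memory() -> int:
    # Hypothetical stand-in for the ..utils helper imported in the diff;
    # assumed to return shared memory in bytes, or 0 when unknown.
    return 0


def query_shared_memory() -> Optional[int]:
    sm_available = None
    try:
        # Preferred path: CUDA device properties (NVIDIA exposes
        # shared_memory_per_block_optin).
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        sm_available = int(props.shared_memory_per_block_optin)
    except Exception:
        pass  # fall through to the ROCm-specific query

    # ROCm fallback, mirroring the hunk above.
    if torch.version.hip and sm_available is None:
        try:
            sm_available = get_gpu_shared_memory()
            if sm_available == 0:
                return None
        except Exception:
            return None
    return sm_available
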
@@ -1318,6 +1330,7 @@ def _finalize_mm_configs(
                 waves_per_eu,
                 matrix_instr_nonkdim,
                 kpack,
+                conf.hint_override,
             )
 
             # Check if gemm specific arg exists - add to key if does
@@ -1344,7 +1357,12 @@ def _finalize_mm_configs(
             }
             if group_m is not None:
                 kwargs["GROUP_M"] = group_m
-            yield self.triton_config(**kwargs)
+
+            tc = self.triton_config(**kwargs)
+            # Preserve hint_override for multi-kernel support
+            if hasattr(conf, "hint_override") and conf.hint_override is not None:
+                tc.hint_override = conf.hint_override
+            yield tc
 
     def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
         flex_attn_fwd_configs: list[FlexConfig] = []
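
Taken together, the two _finalize_mm_configs hunks thread hint_override through both deduplication and emission: it joins the dedup key so configs that differ only by size bucket survive, and it is copied onto the generated triton config so downstream multi-kernel code can read it back. A condensed sketch of that pattern, with the config fields simplified to hypothetical block sizes:

from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class Conf:
    block_m: int
    block_n: int
    hint_override: Optional[int] = None


def finalize(confs: list[Conf]) -> Iterator[Conf]:
    used: set[tuple] = set()
    for conf in confs:
        # hint_override is part of the key, so otherwise-identical configs
        # aimed at different size buckets are not collapsed into one.
        key = (conf.block_m, conf.block_n, conf.hint_override)
        if key in used:
            continue
        used.add(key)
        out = Conf(conf.block_m, conf.block_n)
        # Preserve hint_override on the emitted config, as the diff does.
        if conf.hint_override is not None:
            out.hint_override = conf.hint_override
        yield out


# Both configs survive: they differ only in their size bucket.
assert len(list(finalize([Conf(64, 64), Conf(64, 64, 256)]))) == 2
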
@@ -1674,6 +1692,12 @@ def _convert_config_to_template_kwargs(
         group_m = triton_config.kwargs.get("GROUP_M", 8)
         options_dict["GROUP_M"] = group_m
 
+        # Keep ROCm multi-kernel size bucket attached to the config
+        if torch.version.hip and "hint_override" not in options_dict:
+            hint_override = getattr(triton_config, "hint_override", None)
+            if hint_override is not None:
+                options_dict["hint_override"] = hint_override
+
         return options_dict
 
     def _get_acc_type(self, dtype: torch.dtype) -> str:
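
Finally, _convert_config_to_template_kwargs forwards the preserved attribute into the template kwargs, gated on ROCm and on the key not already being present. A small sketch of the read side, using a SimpleNamespace as a hypothetical stand-in for the triton config object:

from types import SimpleNamespace

import torch


def to_template_kwargs(triton_config) -> dict:
    options_dict = {"GROUP_M": triton_config.kwargs.get("GROUP_M", 8)}
    # getattr with a None default keeps this safe for configs produced on
    # paths that never set hint_override.
    if torch.version.hip and "hint_override" not in options_dict:
        hint_override = getattr(triton_config, "hint_override", None)
        if hint_override is not None:
            options_dict["hint_override"] = hint_override
    return options_dict


cfg = SimpleNamespace(kwargs={"GROUP_M": 8}, hint_override=256)
# On ROCm builds this yields {"GROUP_M": 8, "hint_override": 256};
# on CUDA builds the hint is intentionally dropped.
print(to_template_kwargs(cfg))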