@@ -540,43 +540,34 @@ def _scale_mm_configs(
         return scaled_configs
 
-    def _exceed_available_shared_memeory(
-        self, gemm_config: BaseConfig, dtype_size: int
-    ) -> bool:
-        try:
-            if dtype_size <= 0:
-                return False
-
-            device = torch.cuda.current_device()
-            props = torch.cuda.get_device_properties(device)
-            if not hasattr(props, "shared_memory_per_block_optin"):
-                return False
-            sm_available = props.shared_memory_per_block_optin  # type: ignore[attr-defined]
-            shared_mem_accum = dtype_size * (
-                gemm_config.block_m * gemm_config.block_k
-                + gemm_config.block_n * gemm_config.block_k
-            )
-            return shared_mem_accum * gemm_config.num_stages > sm_available
-        except Exception:
-            return False
-
     def _prune_exhaustive_configs(
         self,
         configs: list[BaseConfig],
         dtype_size: int,
     ) -> list[BaseConfig]:
+        import torch
+
         pruned_configs = []
         for gemm_config in configs:
-            # Will use more shared memory than available
-            if self._exceed_available_shared_memeory(gemm_config, dtype_size):
-                continue
-
+            device = torch.cuda.current_device()
+            props = torch.cuda.get_device_properties(device)
+            sm_available = props.shared_memory_per_block_optin  # type: ignore[attr-defined]
             NUM_REG = 255
+
             acc_regs = math.ceil(
                 gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32)
             )
+
+            shared_mem_accum = dtype_size * (
+                gemm_config.block_m * gemm_config.block_k
+                + gemm_config.block_n * gemm_config.block_k
+            )
+
+            # Will use more shared memory than available
+            if shared_mem_accum * gemm_config.num_stages > sm_available:
+                continue
             # Lower bound for register spillage, if exceeds the kernel will certainly spill
-            if acc_regs > NUM_REG:
+            elif acc_regs > NUM_REG:
                 continue
 
             pruned_configs.append(gemm_config)
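For reference, the check that this hunk inlines into _prune_exhaustive_configs can be written as a standalone sketch. This is illustrative only: should_prune_config and its flat parameter list are hypothetical and not part of the PR, sm_available stands for torch.cuda.get_device_properties(device).shared_memory_per_block_optin, and the block/stage/warp fields mirror BaseConfig.

import math

def should_prune_config(block_m, block_n, block_k, num_stages, num_warps,
                        dtype_size, sm_available, num_reg=255):
    # Bytes of shared memory for the A and B tiles of a single pipeline stage.
    shared_mem_per_stage = dtype_size * (block_m * block_k + block_n * block_k)
    # Will use more shared memory than available once all stages are resident.
    if shared_mem_per_stage * num_stages > sm_available:
        return True
    # Lower bound on accumulator registers per thread (32 threads per warp);
    # above 255 the kernel will certainly spill.
    acc_regs = math.ceil(block_m * block_n / (num_warps * 32))
    return acc_regs > num_reg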
@@ -608,15 +599,6 @@ def preprocess_mm_configs(
         scaled_configs = self._scale_mm_configs(
             m, n, k, configs, scale, has_int8_tensor, exclude
         )
-
-        # Filter out configs that require more shared memory than is available.
-        if dtype_size > 0 and config.max_autotune_prune_choices_based_on_shared_mem:
-            scaled_configs = [
-                c
-                for c in scaled_configs
-                if not self._exceed_available_shared_memeory(c, dtype_size)
-            ]
-
         if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
             assert dtype_size > 0, "dtype_size must be provided for exhaustive search"
             scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size)
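A hedged worked example of the sketch above (the 102400-byte limit is an assumed value; the real shared_memory_per_block_optin varies by GPU):

# Pruned for shared memory: 2 * (128*64 + 128*64) = 32768 bytes per stage,
# times 4 stages = 131072 bytes > 102400.
should_prune_config(block_m=128, block_n=128, block_k=64, num_stages=4,
                    num_warps=8, dtype_size=2, sm_available=102400)  # True

# Pruned for register pressure: shared memory fits (2 * 12288 * 2 = 49152 bytes),
# but ceil(128 * 256 / (4 * 32)) = 256 accumulator registers > 255.
should_prune_config(block_m=128, block_n=256, block_k=32, num_stages=2,
                    num_warps=4, dtype_size=2, sm_available=102400)  # True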