@@ -1913,6 +1913,9 @@ def pointwise(
19131913 triton_config_with_settings (
19141914 size_hints , TRITON_MAX_BLOCK ["X" ], waves_per_eu = 2
19151915 ),
1916+ triton_config_with_settings (
1917+ size_hints , 4096 # wrt: better than the max_block for some kernel
1918+ ),
19161919 * hinted_configs ,
19171920 ]
19181921 # Additional reduction configs appended for ROCm builds
@@ -1935,10 +1938,12 @@ def pointwise(
19351938 else :
19361939 configs = [
19371940 triton_config_with_settings (size_hints , 32 , 32 ),
1941+ triton_config_with_settings (size_hints , 64 , 32 ), # wrt: better for some kernels
19381942 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
19391943 triton_config_with_settings (size_hints , 256 , 16 ),
19401944 triton_config_with_settings (size_hints , 16 , 256 ),
19411945 triton_config_with_settings (size_hints , 128 , 16 ), # wrt: +10% for some kernels
1946+ triton_config_with_settings (size_hints , 128 , 32 ), # wrt: ..additional 10% more
19421947 triton_config_with_settings (size_hints , 32 , 512 ), # wrt: +30% for some kernels
19431948 triton_config_with_settings (size_hints , bs , 1 ),
19441949 triton_config_with_settings (size_hints , 1 , bs ),
0 commit comments