@@ -1945,6 +1945,15 @@ def pointwise(
19451945 size_hints , 2048 , num_warps = 8 , num_stages = 2 , waves_per_eu = 1
19461946 )
19471947 ) # 20% improvement
1948+ configs += [
1949+ triton_config_with_settings (size_hints , 2048 , num_warps = 8 , num_stages = 2 , waves_per_eu = 1 ), # 20% improvement # .. in where?
1950+ triton_config_with_settings (size_hints , 4096 ), # wrt1: better than the max_block for some kernel
1951+ triton_config_with_settings (size_hints , 128 , num_warps = 2 , num_stages = 2 , waves_per_eu = 1 ),
1952+ # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
1953+ # triton_poi_fused_index_put_new_zeros_45
1954+ # triton_poi_fused_index_put_new_zeros_49
1955+ # triton_poi_fused_index_put_new_zeros_54
1956+ ]
19481957 if len (size_hints ) == 2 :
19491958 # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
19501959 # ROCm has observed improvement by diverging here
@@ -1963,7 +1972,7 @@ def pointwise(
19631972 size_hints , 64 , 32
19641973 ), # better for some kernels
19651974 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
1966- triton_config_with_settings (size_hints , 256 , 16 ),
1975+ triton_config_with_settings (size_hints , 256 , 16 ),
19671976 triton_config_with_settings (size_hints , 16 , 256 ),
19681977 triton_config_with_settings (
19691978 size_hints , 128 , 16
@@ -1976,6 +1985,17 @@ def pointwise(
19761985 triton_config_with_settings (size_hints , 1 , bs ),
19771986 * hinted_configs ,
19781987 ]
1988+ if torch .version .hip :
1989+ configs += [ # add here
1990+ ]
1991+ # bypass triton_config_with_settings -> triton_config logic
1992+ if "x" in size_hints and "y" in size_hints :
1993+ configs += [
1994+ Config ({"XBLOCK" : 512 , "YBLOCK" : 8 }, num_warps = 8 ), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
1995+ Config ({"XBLOCK" : 32 , "YBLOCK" : 128 }, num_warps = 4 ), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
1996+ Config ({"XBLOCK" :64 , "YBLOCK" : 32 }, num_warps = 8 ), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
1997+ ]
1998+
19791999 if len (size_hints ) == 3 :
19802000 if disable_pointwise_autotuning (inductor_meta ):
19812001 configs = [triton_config_with_settings (size_hints , 16 , 16 , 16 )]
@@ -2188,8 +2208,13 @@ def outer_config_opt():
21882208 [
21892209 make_config (1024 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 2 ),
21902210 make_config (512 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ),
2191- ]
2192- )
2211+ make_config (128 , 4 , num_warps = 2 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
2212+ make_config (1 , 512 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
2213+ make_config (1 , 4096 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
2214+ make_config (64 , 128 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
2215+ make_config (2 , 2048 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
2216+ ]
2217+ )
21932218
21942219 return result_configs
21952220
0 commit comments