@@ -2526,6 +2526,15 @@ def pointwise(
25262526 size_hints , 2048 , num_warps = 8 , num_stages = 2 , waves_per_eu = 1
25272527 )
25282528 ) # 20% improvement
2529+ configs += [
2530+ triton_config_with_settings (size_hints , 2048 , num_warps = 8 , num_stages = 2 , waves_per_eu = 1 ), # 20% improvement # .. in where?
2531+ triton_config_with_settings (size_hints , 4096 ), # wrt1: better than the max_block for some kernel
2532+ triton_config_with_settings (size_hints , 128 , num_warps = 2 , num_stages = 2 , waves_per_eu = 1 ),
2533+ # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
2534+ # triton_poi_fused_index_put_new_zeros_45
2535+ # triton_poi_fused_index_put_new_zeros_49
2536+ # triton_poi_fused_index_put_new_zeros_54
2537+ ]
25292538 if len (size_hints ) == 2 :
25302539 # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
25312540 # ROCm has observed improvement by diverging here
@@ -2544,7 +2553,7 @@ def pointwise(
25442553 size_hints , 64 , 32
25452554 ), # better for some kernels
25462555 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
2547- triton_config_with_settings (size_hints , 256 , 16 ),
2556+ triton_config_with_settings (size_hints , 256 , 16 ),
25482557 triton_config_with_settings (size_hints , 16 , 256 ),
25492558 triton_config_with_settings (
25502559 size_hints , 128 , 16
@@ -2557,6 +2566,17 @@ def pointwise(
25572566 triton_config_with_settings (size_hints , 1 , bs ),
25582567 * hinted_configs ,
25592568 ]
2569+ if torch .version .hip :
2570+ configs += [ # add here
2571+ ]
2572+ # bypass triton_config_with_settings -> triton_config logic
2573+ if "x" in size_hints and "y" in size_hints :
2574+ configs += [
2575+ Config ({"XBLOCK" : 512 , "YBLOCK" : 8 }, num_warps = 8 ), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
2576+ Config ({"XBLOCK" : 32 , "YBLOCK" : 128 }, num_warps = 4 ), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
2577+ Config ({"XBLOCK" :64 , "YBLOCK" : 32 }, num_warps = 8 ), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
2578+ ]
2579+
25602580 if len (size_hints ) == 3 :
25612581 if disable_pointwise_autotuning (inductor_meta ):
25622582 configs = [triton_config_with_settings (size_hints , 16 , 16 , 16 )]
@@ -2763,8 +2783,13 @@ def outer_config_opt():
27632783 [
27642784 make_config (1024 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 2 ),
27652785 make_config (512 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ),
2766- ]
2767- )
2786+ make_config (128 , 4 , num_warps = 2 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
2787+ make_config (1 , 512 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
2788+ make_config (1 , 4096 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
2789+ make_config (64 , 128 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
2790+ make_config (2 , 2048 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
2791+ ]
2792+ )
27682793
27692794 return result_configs
27702795
0 commit comments