@@ -2415,20 +2415,19 @@ def pointwise(
24152415 triton_config_with_settings (
24162416 size_hints , TRITON_MAX_BLOCK ["X" ], waves_per_eu = 2
24172417 ),
2418- triton_config_with_settings (
2419- size_hints , 4096 # wrt: better than the max_block for some kernel
2420- ),
24212418 * hinted_configs ,
24222419 ]
24232420 # Additional reduction configs appended for ROCm builds
24242421 if torch .version .hip :
2425- configs .append (triton_config_with_settings (
2426- size_hints ,
2427- 2048 ,
2428- num_warps = 8 ,
2429- num_stages = 2 ,
2430- waves_per_eu = 1
2431- )) # 20% improvement
2422+ configs += [
2423+ triton_config_with_settings (size_hints , 2048 , num_warps = 8 , num_stages = 2 , waves_per_eu = 1 ), # 20% improvement # .. in where?
2424+ triton_config_with_settings (size_hints , 4096 ), # wrt1: better than the max_block for some kernel
2425+ triton_config_with_settings (size_hints , 128 , num_warps = 2 , num_stages = 2 , waves_per_eu = 1 ),
2426+ # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
2427+ # triton_poi_fused_index_put_new_zeros_45
2428+ # triton_poi_fused_index_put_new_zeros_49
2429+ # triton_poi_fused_index_put_new_zeros_54
2430+ ]
24322431 if len (size_hints ) == 2 :
24332432 if (
24342433 disable_pointwise_autotuning (inductor_meta ) # or tile_hint == TileHint.SQUARE
@@ -2440,17 +2439,24 @@ def pointwise(
24402439 else :
24412440 configs = [
24422441 triton_config_with_settings (size_hints , 32 , 32 ),
2443- triton_config_with_settings (size_hints , 64 , 32 ), # wrt: better for some kernels
24442442 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
2445- triton_config_with_settings (size_hints , 256 , 16 ),
2443+ triton_config_with_settings (size_hints , 256 , 16 ),
24462444 triton_config_with_settings (size_hints , 16 , 256 ),
2447- triton_config_with_settings (size_hints , 128 , 16 ), # wrt: +10% for some kernels
2448- triton_config_with_settings (size_hints , 128 , 32 ), # wrt: ..additional 10% more
2449- triton_config_with_settings (size_hints , 32 , 512 ), # wrt: +30% for some kernels
24502445 triton_config_with_settings (size_hints , bs , 1 ),
24512446 triton_config_with_settings (size_hints , 1 , bs ),
24522447 * hinted_configs ,
24532448 ]
2449+ if torch .version .hip :
2450+ configs += [ # add here
2451+ ]
2452+ # bypass triton_config_with_settings -> triton_config logic
2453+ if "x" in size_hints and "y" in size_hints :
2454+ configs += [
2455+ Config ({"XBLOCK" : 512 , "YBLOCK" : 8 }, num_warps = 8 ), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
2456+ Config ({"XBLOCK" : 32 , "YBLOCK" : 128 }, num_warps = 4 ), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
2457+ Config ({"XBLOCK" :64 , "YBLOCK" : 32 }, num_warps = 8 ), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
2458+ ]
2459+
24542460 if len (size_hints ) == 3 :
24552461 if disable_pointwise_autotuning (inductor_meta ):
24562462 configs = [triton_config_with_settings (size_hints , 16 , 16 , 16 )]
@@ -2583,9 +2589,14 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, wa
25832589 if torch .version .hip :
25842590 result_configs .extend ([
25852591 make_config (1024 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 2 ),
2586- make_config (512 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 )
2592+ make_config (512 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ),
2593+ make_config (128 , 4 , num_warps = 2 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
2594+ make_config (1 , 512 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
2595+ make_config (1 , 4096 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
2596+ make_config (64 , 128 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
2597+ make_config (2 , 2048 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
25872598 ])
2588-
2599+
25892600 return result_configs
25902601
25912602
0 commit comments