@@ -2415,15 +2415,18 @@ def pointwise(
24152415 triton_config_with_settings (
24162416 size_hints , TRITON_MAX_BLOCK ["X" ], waves_per_eu = 2
24172417 ),
2418- triton_config_with_settings (
2419- size_hints , 4096 # wrt: better than the max_block for some kernel
2420- ),
24212418 * hinted_configs ,
24222419 ]
24232420 # Additional reduction configs appended for ROCm builds
24242421 if torch .version .hip :
24252422 configs += [
2426- triton_config_with_settings (size_hints , 2048 , num_warps = 8 , num_stages = 2 , waves_per_eu = 1 ) # 20% improvement # ..where?
2423+ triton_config_with_settings (size_hints , 2048 , num_warps = 8 , num_stages = 2 , waves_per_eu = 1 ), # 20% improvement # .. in where?
2424+ triton_config_with_settings (size_hints , 4096 ), # wrt1: better than the max_block for some kernel
2425+ triton_config_with_settings (size_hints , 128 , num_warps = 2 , num_stages = 2 , waves_per_eu = 1 ),
2426+ # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
2427+ # triton_poi_fused_index_put_new_zeros_45
2428+ # triton_poi_fused_index_put_new_zeros_49
2429+ # triton_poi_fused_index_put_new_zeros_54
24272430 ]
24282431 if len (size_hints ) == 2 :
24292432 if (
@@ -2437,7 +2440,7 @@ def pointwise(
24372440 configs = [
24382441 triton_config_with_settings (size_hints , 32 , 32 ),
24392442 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
2440- triton_config_with_settings (size_hints , 256 , 16 ),
2443+ triton_config_with_settings (size_hints , 256 , 16 ),
24412444 triton_config_with_settings (size_hints , 16 , 256 ),
24422445 triton_config_with_settings (size_hints , bs , 1 ),
24432446 triton_config_with_settings (size_hints , 1 , bs ),
@@ -2448,8 +2451,8 @@ def pointwise(
24482451 ]
24492452 # bypass triton_config_with_settings -> triton_config logic
24502453 if "x" in size_hints and "y" in size_hints :
2451- cfg = {"XBLOCK" : 32 , "YBLOCK" : 128 }
24522454 configs += [
2455+ Config ({"XBLOCK" : 512 , "YBLOCK" : 8 }, num_warps = 8 ), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
24532456 Config ({"XBLOCK" : 32 , "YBLOCK" : 128 }, num_warps = 4 ), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
24542457 Config ({"XBLOCK" :64 , "YBLOCK" : 32 }, num_warps = 8 ), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
24552458 ]
@@ -2474,12 +2477,6 @@ def pointwise(
24742477
24752478 configs = _maybe_filter_configs_for_tma_restrictions (inductor_meta , configs )
24762479
2477- print ()
2478- print ("Pointwise will use following configs" )
2479- for config in configs :
2480- print (">" , config )
2481- print ()
2482-
24832480 return cached_autotune (
24842481 size_hints ,
24852482 configs ,
@@ -2699,12 +2696,6 @@ def reduction(
26992696 configs = _reduction_configs (size_hints = size_hints , inductor_meta = inductor_meta )
27002697 configs = _maybe_filter_configs_for_tma_restrictions (inductor_meta , configs )
27012698
2702- print ()
2703- print ("Reduction will use following configs" )
2704- for config in configs :
2705- print (">" , config )
2706- print ()
2707-
27082699 return cached_autotune (
27092700 size_hints ,
27102701 configs = configs ,
0 commit comments