@@ -2203,18 +2203,18 @@ def outer_config_opt():
22032203 make_config (64 , 4 , num_warps = 8 ),
22042204 ]
22052205
2206- if torch .version .hip :
2207- result_configs .extend (
2208- [
2209- make_config (1024 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 2 ),
2210- make_config (512 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ),
2211- make_config (128 , 4 , num_warps = 2 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
2212- make_config (1 , 512 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
2213- make_config (1 , 4096 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
2214- make_config (64 , 128 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
2215- make_config (2 , 2048 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
2216- ]
2217- )
2206+ if torch .version .hip :
2207+ result_configs .extend (
2208+ [
2209+ make_config (1024 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 2 ),
2210+ make_config (512 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ),
2211+ make_config (128 , 4 , num_warps = 2 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
2212+ make_config (1 , 512 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
2213+ make_config (1 , 4096 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
2214+ make_config (64 , 128 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
2215+ make_config (2 , 2048 , num_warps = 8 , num_stages = 1 , waves_per_eu = 1 ) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
2216+ ]
2217+ )
22182218
22192219 return result_configs
22202220
0 commit comments