@@ -2517,8 +2517,11 @@ def pointwise(
25172517 waves_per_eu = 1
25182518 )) # 20% improvement
25192519 if len (size_hints ) == 2 :
2520+ # Only skip autotuning for TileHint.SQUARE on non-ROCm builds;
2521+ # ROCm builds have observed improvements by diverging here.
25202522 if (
2521- disable_pointwise_autotuning (inductor_meta ) # or tile_hint == TileHint.SQUARE
2523+ disable_pointwise_autotuning (inductor_meta )
2524+ or (torch .version .hip is None and tile_hint == TileHint .SQUARE )
25222525 ) and not (
25232526 inductor_meta .get ("max_autotune" )
25242527 or inductor_meta .get ("max_autotune_pointwise" )
@@ -2527,13 +2530,13 @@ def pointwise(
25272530 else :
25282531 configs = [
25292532 triton_config_with_settings (size_hints , 32 , 32 ),
2530- triton_config_with_settings (size_hints , 64 , 32 ), # wrt: better for some kernels
2533+ triton_config_with_settings (size_hints , 64 , 32 ), # better for some kernels
25312534 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
25322535 triton_config_with_settings (size_hints , 256 , 16 ),
25332536 triton_config_with_settings (size_hints , 16 , 256 ),
2534- triton_config_with_settings (size_hints , 128 , 16 ), # wrt: +10% for some kernels
2535- triton_config_with_settings (size_hints , 128 , 32 ), # wrt: .. additional 10% more
2536- triton_config_with_settings (size_hints , 32 , 512 ), # wrt: +30% for some kernels
2537+ triton_config_with_settings (size_hints , 128 , 16 ), # +10% for some kernels
2538+ triton_config_with_settings (size_hints , 128 , 32 ), # additional 10% more
2539+ triton_config_with_settings (size_hints , 32 , 512 ), # +30% for some kernels
25372540 triton_config_with_settings (size_hints , bs , 1 ),
25382541 triton_config_with_settings (size_hints , 1 , bs ),
25392542 * hinted_configs ,
0 commit comments