@@ -1928,8 +1928,11 @@ def pointwise(
19281928 waves_per_eu = 1
19291929 )) # 20% improvement
19301930 if len (size_hints ) == 2 :
1931+ # Skip autotuning for TileHint.SQUARE only on non-ROCm builds;
1932+ # ROCm has observed improvements by diverging here.
19311933 if (
1932- disable_pointwise_autotuning (inductor_meta ) # or tile_hint == TileHint.SQUARE
1934+ disable_pointwise_autotuning (inductor_meta )
1935+ or (torch .version .hip is None and tile_hint == TileHint .SQUARE )
19331936 ) and not (
19341937 inductor_meta .get ("max_autotune" )
19351938 or inductor_meta .get ("max_autotune_pointwise" )
@@ -1938,13 +1941,13 @@ def pointwise(
19381941 else :
19391942 configs = [
19401943 triton_config_with_settings (size_hints , 32 , 32 ),
1941- triton_config_with_settings (size_hints , 64 , 32 ), # wrt: better for some kernels
1944+ triton_config_with_settings (size_hints , 64 , 32 ), # better for some kernels
19421945 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
19431946 triton_config_with_settings (size_hints , 256 , 16 ),
19441947 triton_config_with_settings (size_hints , 16 , 256 ),
1945- triton_config_with_settings (size_hints , 128 , 16 ), # wrt: +10% for some kernels
1946- triton_config_with_settings (size_hints , 128 , 32 ), # wrt: .. additional 10% more
1947- triton_config_with_settings (size_hints , 32 , 512 ), # wrt: +30% for some kernels
1948+ triton_config_with_settings (size_hints , 128 , 16 ), # +10% for some kernels
1949+ triton_config_with_settings (size_hints , 128 , 32 ), # additional 10% more
1950+ triton_config_with_settings (size_hints , 32 , 512 ), # +30% for some kernels
19481951 triton_config_with_settings (size_hints , bs , 1 ),
19491952 triton_config_with_settings (size_hints , 1 , bs ),
19501953 * hinted_configs ,
0 commit comments