@@ -2415,6 +2415,9 @@ def pointwise(
24152415 triton_config_with_settings (
24162416 size_hints , TRITON_MAX_BLOCK ["X" ], waves_per_eu = 2
24172417 ),
2418+ triton_config_with_settings (
2419+ size_hints , 4096 # wrt: better than the max_block for some kernel
2420+ ),
24182421 * hinted_configs ,
24192422 ]
24202423 # Additional reduction configs appended for ROCm builds
@@ -2437,10 +2440,12 @@ def pointwise(
24372440 else :
24382441 configs = [
24392442 triton_config_with_settings (size_hints , 32 , 32 ),
2443+ triton_config_with_settings (size_hints , 64 , 32 ), # wrt: better for some kernels
24402444 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
24412445 triton_config_with_settings (size_hints , 256 , 16 ),
24422446 triton_config_with_settings (size_hints , 16 , 256 ),
24432447 triton_config_with_settings (size_hints , 128 , 16 ), # wrt: +10% for some kernels
2448+ triton_config_with_settings (size_hints , 128 , 32 ), # wrt: ..additional 10% more
24442449 triton_config_with_settings (size_hints , 32 , 512 ), # wrt: +30% for some kernels
24452450 triton_config_with_settings (size_hints , bs , 1 ),
24462451 triton_config_with_settings (size_hints , 1 , bs ),
0 commit comments