@@ -2502,6 +2502,9 @@ def pointwise(
25022502 triton_config_with_settings (
25032503 size_hints , TRITON_MAX_BLOCK ["X" ], waves_per_eu = 2
25042504 ),
2505+ triton_config_with_settings (
2506+ size_hints , 4096 # wrt: better than the max_block for some kernel
2507+ ),
25052508 * hinted_configs ,
25062509 ]
25072510 # Additional reduction configs appended for ROCm builds
@@ -2524,10 +2527,12 @@ def pointwise(
25242527 else :
25252528 configs = [
25262529 triton_config_with_settings (size_hints , 32 , 32 ),
2530+ triton_config_with_settings (size_hints , 64 , 32 ), # wrt: better for some kernels
25272531 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
25282532 triton_config_with_settings (size_hints , 256 , 16 ),
25292533 triton_config_with_settings (size_hints , 16 , 256 ),
25302534 triton_config_with_settings (size_hints , 128 , 16 ), # wrt: +10% for some kernels
2535+ triton_config_with_settings (size_hints , 128 , 32 ), # wrt: ..additional 10% more
25312536 triton_config_with_settings (size_hints , 32 , 512 ), # wrt: +30% for some kernels
25322537 triton_config_with_settings (size_hints , bs , 1 ),
25332538 triton_config_with_settings (size_hints , 1 , bs ),
0 commit comments