@@ -2505,6 +2505,9 @@ def pointwise(
25052505 triton_config_with_settings (
25062506 size_hints , TRITON_MAX_BLOCK ["X" ], waves_per_eu = 2
25072507 ),
2508+ triton_config_with_settings (
2509+ size_hints , 4096 # wrt: better than the max_block for some kernel
2510+ ),
25082511 * hinted_configs ,
25092512 ]
25102513 if len (size_hints ) == 2 :
@@ -2518,10 +2521,12 @@ def pointwise(
25182521 else :
25192522 configs = [
25202523 triton_config_with_settings (size_hints , 32 , 32 ),
2524+ triton_config_with_settings (size_hints , 64 , 32 ), # wrt: better for some kernels
25212525 triton_config_with_settings (size_hints , 64 , 64 ), # ~8% better for fp16
25222526 triton_config_with_settings (size_hints , 256 , 16 ),
25232527 triton_config_with_settings (size_hints , 16 , 256 ),
25242528 triton_config_with_settings (size_hints , 128 , 16 ), # wrt: +10% for some kernels
2529+ triton_config_with_settings (size_hints , 128 , 32 ), # wrt: ..additional 10% more
25252530 triton_config_with_settings (size_hints , 32 , 512 ), # wrt: +30% for some kernels
25262531 triton_config_with_settings (size_hints , bs , 1 ),
25272532 triton_config_with_settings (size_hints , 1 , bs ),
0 commit comments