@@ -485,24 +485,45 @@ utils::uvec3 ComputeGraph::create_local_wg_size(
485485 return config_.local_wg_size_override ;
486486 }
487487
488- utils::uvec3 local_group_size = {4 , 4 , 4 };
488+ // array containing axis index and global workgroup size
489+ std::pair<uint32_t , uint32_t > global_wg_size_desc[] = {
490+ {0u , global_wg_size[0 ]},
491+ {1u , global_wg_size[1 ]},
492+ {2u , global_wg_size[2 ]}};
493+
494+ // sort the global workgroup size in descending order
495+ if (global_wg_size_desc[0 ].second < global_wg_size_desc[1 ].second ) {
496+ std::swap (global_wg_size_desc[0 ], global_wg_size_desc[1 ]);
497+ }
498+ if (global_wg_size_desc[1 ].second < global_wg_size_desc[2 ].second ) {
499+ std::swap (global_wg_size_desc[1 ], global_wg_size_desc[2 ]);
500+ }
501+ if (global_wg_size_desc[0 ].second < global_wg_size_desc[1 ].second ) {
502+ std::swap (global_wg_size_desc[0 ], global_wg_size_desc[1 ]);
503+ }
489504
490- if (global_wg_size[2u ] == 1 ) {
491- if (global_wg_size[1u ] == 1 ) {
505+ utils::uvec3 local_group_size = {
506+ 8 ,
507+ std::max (1u , std::min (4u , global_wg_size_desc[1 ].second )),
508+ std::max (1u , std::min (2u , global_wg_size_desc[2 ].second ))};
509+
510+ if (global_wg_size_desc[2u ].second == 1 ) {
511+ if (global_wg_size_desc[1u ].second == 1 ) {
492512 local_group_size[0u ] = 64 ;
493513 local_group_size[1u ] = 1 ;
494- local_group_size[2u ] = 1 ;
495- } else if (global_wg_size[1u ] < 8 ) {
514+ } else if (global_wg_size_desc[1u ].second % 4 == 0 ) {
496515 local_group_size[0u ] = 16 ;
497516 local_group_size[1u ] = 4 ;
498- local_group_size[2u ] = 1 ;
499517 } else {
500- local_group_size[0u ] = 8 ;
501- local_group_size[1u ] = 8 ;
502- local_group_size[2u ] = 1 ;
518+ local_group_size[0u ] = 32 ;
519+ local_group_size[1u ] = 2 ;
503520 }
504521 }
505- return local_group_size;
522+
523+ return {
524+ local_group_size[global_wg_size_desc[0 ].first ],
525+ local_group_size[global_wg_size_desc[1 ].first ],
526+ local_group_size[global_wg_size_desc[2 ].first ]};
506527}
507528
508529utils::uvec3 ComputeGraph::create_local_wg_size (const ValueRef idx) {
0 commit comments