@@ -522,6 +522,26 @@ static std::vector<NnVulkanBatchInfo> buildBatchInfo(NnOpConfig *opConfig, NnVul
522522 return offset;
523523}
524524
525+ static NnUint roundUpPow2 (NnUint n, NnUint min, NnUint max) {
526+ NnUint p = 1 ;
527+ while (p << 1 <= n) p <<= 1 ;
528+ if (p < n) p <<= 1 ;
529+ if (p < min) p = min;
530+ if (p > max) p = max;
531+ return p;
532+ }
533+
534+ static uint32_t resolveShaderNThreads (const NnOpConfig *opConfig, const NnSize2D inputSize) {
535+ if (opConfig->code == OP_MATMUL) {
536+ if (opConfig->weightSize .floatType == F_Q40) {
537+ constexpr NnUint maxThreads = 256 ; // Shader constant
538+ NnUint t = roundUpPow2 (inputSize.x / (Q40_BLOCK_SIZE * 2 ), 32 , maxThreads);
539+ return t;
540+ }
541+ }
542+ return 0 ;
543+ }
544+
525545static void resolveShaderGroups (const NnOpConfig *opConfig, const NnUint batchSize, NnUint *groupCount, const NnSize2D inputSize, const NnSize2D outputSize) {
526546 groupCount[0 ] = 1 ;
527547 groupCount[1 ] = batchSize;
@@ -531,7 +551,7 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
531551 if (outputSize.floatType == F_Q80) {
532552 groupCount[2 ] = outputSize.x / Q80_BLOCK_SIZE;
533553 } else {
534- constexpr NnUint chunkSize = 4 ;
554+ constexpr NnUint chunkSize = 4 ; // Shader constant
535555 groupCount[2 ] = outputSize.x / chunkSize;
536556 }
537557 } else if (opConfig->code == OP_MERGE_ADD) {
@@ -542,9 +562,8 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
542562 }
543563 } else if (opConfig->code == OP_MATMUL) {
544564 if (opConfig->weightSize .floatType == F_Q40) {
545- // Must be synced with the shader
546- constexpr NnUint tileSizeN = 2 ;
547- constexpr NnUint tileSizeD = 8 ;
565+ constexpr NnUint tileSizeN = 2 ; // Shader constant
566+ constexpr NnUint tileSizeD = 8 ; // Shader constant
548567 const NnUint blockSize = getBlockSize (opConfig->weightSize .floatType );
549568 assert (opConfig->weightSize .y % (tileSizeN * blockSize) == 0 );
550569 assert (opConfig->weightSize .x % tileSizeD == 0 );
@@ -564,7 +583,7 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
564583 opConfig->code == OP_SILU ||
565584 opConfig->code == OP_SHIFT
566585 ) {
567- constexpr NnUint chunkSize = 4 ;
586+ constexpr NnUint chunkSize = 4 ; // Shader constant
568587 assert (outputSize.x % chunkSize == 0 );
569588 groupCount[2 ] = outputSize.x / chunkSize;
570589 }
@@ -663,12 +682,10 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanD
663682
664683 std::vector<vk::PipelineShaderStageCreateInfo> shaderCreateInfos (segmentConfig->nOps );
665684
666- constexpr NnUint maxConsts = 3 ;
667- std::vector<NnUint> nConsts (segmentConfig->nOps );
668- std::vector<int > consts (segmentConfig->nOps * maxConsts);
669685 std::vector<vk::SpecializationInfo> specInfos (segmentConfig->nOps );
670- std::vector<vk::SpecializationMapEntry> specMapEntries (segmentConfig->nOps * maxConsts);
671-
686+ std::vector<vk::SpecializationMapEntry> specEntries (segmentConfig->nOps );
687+ std::vector<uint32_t > nThreads (segmentConfig->nOps );
688+
672689 for (NnUint opIndex = 0 ; opIndex < segmentConfig->nOps ; opIndex++) {
673690 NnOpConfig *opConfig = &segmentConfig->ops [opIndex];
674691 NnSize2D inputSize = data->resolveBufferSize (&opConfig->input );
@@ -690,12 +707,17 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanD
690707 code.data ()
691708 );
692709
710+ nThreads[opIndex] = resolveShaderNThreads (opConfig, inputSize);
711+ specEntries[opIndex] = vk::SpecializationMapEntry (0 , 0 , sizeof (uint32_t ));
712+ specInfos[opIndex] = vk::SpecializationInfo (1 , &specEntries[opIndex], sizeof (uint32_t ), &nThreads[opIndex]);
713+
693714 vk::ShaderModule shaderModule = context->device .createShaderModule (shaderModuleCreateInfo);
694715 vk::PipelineShaderStageCreateInfo shaderCreateInfo (
695716 vk::PipelineShaderStageCreateFlags (),
696717 vk::ShaderStageFlagBits::eCompute,
697718 shaderModule,
698- " main"
719+ " main" ,
720+ &specInfos[opIndex]
699721 );
700722
701723 shaderModules[opIndex] = shaderModule;
0 commit comments