Skip to content

Commit 7cf3ee9

Browse files
committed
feat: tweaks.
1 parent 1f9c5f9 commit 7cf3ee9

File tree

3 files changed

+42
-19
lines changed

3 files changed

+42
-19
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ VULKAN_SHADER_BINS := $(VULKAN_SHADER_SRCS:.comp=.spv)
7070
DEPS += $(VULKAN_SHADER_BINS)
7171

7272
%.spv: %.comp
73-
$(CGLSLC) -c $< -o $@ --target-env=vulkan1.1
73+
$(CGLSLC) -c $< -o $@ --target-env=vulkan1.2
7474
nn-vulkan-test: src/nn/nn-vulkan-test.cpp nn-quants.o nn-core.o nn-executor.o nn-vulkan.o ${DEPS}
7575
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)
7676
endif

src/nn/nn-vulkan.cpp

Lines changed: 33 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -522,6 +522,26 @@ static std::vector<NnVulkanBatchInfo> buildBatchInfo(NnOpConfig *opConfig, NnVul
522522
return offset;
523523
}
524524

525+
// Rounds `n` up to the nearest power of two, then clamps the result to [min, max].
//
// n   - value to round up; 0 and 1 both round to 1 before clamping.
// min - lower bound applied to the rounded value.
// max - upper bound, applied after `min` (so `max` wins when min > max).
//
// Returns the smallest power of two >= n, clamped to [min, max]. If no power of
// two >= n is representable in NnUint, the highest representable power of two is
// used before clamping instead of wrapping around.
static NnUint roundUpPow2(NnUint n, NnUint min, NnUint max) {
    // Highest power of two that fits in NnUint; doubling it would wrap to 0,
    // which made the previous `while (p << 1 <= n)` loop spin forever for
    // any n with its top bit set.
    constexpr NnUint highBit = static_cast<NnUint>(~(~static_cast<NnUint>(0) >> 1));

    NnUint p = 1;
    // Grow until p covers n. Stopping at `max` is safe because the clamp below
    // would discard any further growth; stopping at `highBit` prevents the wrap.
    while (p < n && p < max && p < highBit) {
        p <<= 1;
    }

    if (p < min) p = min;
    if (p > max) p = max;
    return p;
}
533+
534+
static uint32_t resolveShaderNThreads(const NnOpConfig *opConfig, const NnSize2D inputSize) {
535+
if (opConfig->code == OP_MATMUL) {
536+
if (opConfig->weightSize.floatType == F_Q40) {
537+
constexpr NnUint maxThreads = 256; // Shader constant
538+
NnUint t = roundUpPow2(inputSize.x / (Q40_BLOCK_SIZE * 2), 32, maxThreads);
539+
return t;
540+
}
541+
}
542+
return 0;
543+
}
544+
525545
static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSize, NnUint *groupCount, const NnSize2D inputSize, const NnSize2D outputSize) {
526546
groupCount[0] = 1;
527547
groupCount[1] = batchSize;
@@ -531,7 +551,7 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
531551
if (outputSize.floatType == F_Q80) {
532552
groupCount[2] = outputSize.x / Q80_BLOCK_SIZE;
533553
} else {
534-
constexpr NnUint chunkSize = 4;
554+
constexpr NnUint chunkSize = 4; // Shader constant
535555
groupCount[2] = outputSize.x / chunkSize;
536556
}
537557
} else if (opConfig->code == OP_MERGE_ADD) {
@@ -542,9 +562,8 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
542562
}
543563
} else if (opConfig->code == OP_MATMUL) {
544564
if (opConfig->weightSize.floatType == F_Q40) {
545-
// Must be synced with the shader
546-
constexpr NnUint tileSizeN = 2;
547-
constexpr NnUint tileSizeD = 8;
565+
constexpr NnUint tileSizeN = 2; // Shader constant
566+
constexpr NnUint tileSizeD = 8; // Shader constant
548567
const NnUint blockSize = getBlockSize(opConfig->weightSize.floatType);
549568
assert(opConfig->weightSize.y % (tileSizeN * blockSize) == 0);
550569
assert(opConfig->weightSize.x % tileSizeD == 0);
@@ -564,7 +583,7 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
564583
opConfig->code == OP_SILU ||
565584
opConfig->code == OP_SHIFT
566585
) {
567-
constexpr NnUint chunkSize = 4;
586+
constexpr NnUint chunkSize = 4; // Shader constant
568587
assert(outputSize.x % chunkSize == 0);
569588
groupCount[2] = outputSize.x / chunkSize;
570589
}
@@ -663,12 +682,10 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanD
663682

664683
std::vector<vk::PipelineShaderStageCreateInfo> shaderCreateInfos(segmentConfig->nOps);
665684

666-
constexpr NnUint maxConsts = 3;
667-
std::vector<NnUint> nConsts(segmentConfig->nOps);
668-
std::vector<int> consts(segmentConfig->nOps * maxConsts);
669685
std::vector<vk::SpecializationInfo> specInfos(segmentConfig->nOps);
670-
std::vector<vk::SpecializationMapEntry> specMapEntries(segmentConfig->nOps * maxConsts);
671-
686+
std::vector<vk::SpecializationMapEntry> specEntries(segmentConfig->nOps);
687+
std::vector<uint32_t> nThreads(segmentConfig->nOps);
688+
672689
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
673690
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
674691
NnSize2D inputSize = data->resolveBufferSize(&opConfig->input);
@@ -690,12 +707,17 @@ NnVulkanDeviceSegment::NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanD
690707
code.data()
691708
);
692709

710+
nThreads[opIndex] = resolveShaderNThreads(opConfig, inputSize);
711+
specEntries[opIndex] = vk::SpecializationMapEntry(0, 0, sizeof(uint32_t));
712+
specInfos[opIndex] = vk::SpecializationInfo(1, &specEntries[opIndex], sizeof(uint32_t), &nThreads[opIndex]);
713+
693714
vk::ShaderModule shaderModule = context->device.createShaderModule(shaderModuleCreateInfo);
694715
vk::PipelineShaderStageCreateInfo shaderCreateInfo(
695716
vk::PipelineShaderStageCreateFlags(),
696717
vk::ShaderStageFlagBits::eCompute,
697718
shaderModule,
698-
"main"
719+
"main",
720+
&specInfos[opIndex]
699721
);
700722

701723
shaderModules[opIndex] = shaderModule;

src/nn/vulkan/matmul-forward-q80-q40-f32.comp

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,14 @@
44
#extension GL_EXT_shader_16bit_storage : enable
55
#extension GL_EXT_shader_explicit_arithmetic_types : enable
66

7-
#define N_THREADS 64
7+
#define MAX_THREADS 256
88
#define N_BATCHES 32
99
#define TILE_SIZE_X 2
1010
#define TILE_SIZE_D 8
1111

1212
#define Q80_Q40_BLOCK_SIZE 32
1313

14-
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
14+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1515

1616
struct BatchInfo {
1717
uint inputOffset;
@@ -35,17 +35,18 @@ layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
3535
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
3636
layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; };
3737

38-
shared float16_t sums[N_THREADS * TILE_SIZE_D];
38+
shared float16_t sums[MAX_THREADS * TILE_SIZE_D];
3939

4040
void main() {
41+
const uint nThreads = gl_WorkGroupSize.x;
4142
const uint threadIndex = gl_LocalInvocationID.x;
4243
const uint batchIndex = gl_WorkGroupID.y;
4344
const uint workGroupIndex = gl_WorkGroupID.z;
4445
const BatchInfo info = infos[batchIndex];
4546

4647
const uint xTiles = info.inputSizeX / TILE_SIZE_X;
47-
const uint xSlice = xTiles / N_THREADS;
48-
const uint xRest = xTiles % N_THREADS;
48+
const uint xSlice = xTiles / nThreads;
49+
const uint xRest = xTiles % nThreads;
4950

5051
const uint inputOffset = info.inputOffset;
5152
const uint inputSizeX = info.inputSizeX;
@@ -97,15 +98,15 @@ void main() {
9798

9899
barrier();
99100

100-
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
101+
for (uint i = nThreads / 2; i > 0; i >>= 1) {
101102
for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
102103
if (threadIndex < i) {
103104
sums[threadIndex * TILE_SIZE_D + dt] += sums[(threadIndex + i) * TILE_SIZE_D + dt];
104105
}
105106
}
106107
barrier();
107108
}
108-
for (uint dt = threadIndex; dt < TILE_SIZE_D; dt += N_THREADS) {
109+
for (uint dt = threadIndex; dt < TILE_SIZE_D; dt += nThreads) {
109110
y[outputOffset + d + dt] = float(sums[dt]);
110111
}
111112
}

0 commit comments

Comments
 (0)