Skip to content

Commit a85977d

Browse files
authored
[webgpu] Limit that K must be divisible by 128 to apply dp4a matmul (microsoft#24078)
The DP4AMatMulQuantize shader requires K to be divisible by 128. Otherwise, we would need to align the scale to have shape [M, ceil(K / 128)]. To keep the shader simple, we apply dp4a matmul only when K is divisible by 128.
1 parent 12fea57 commit a85977d

File tree

3 files changed

+5
-14
lines changed

3 files changed

+5
-14
lines changed

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,12 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
1212
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
1313
shader.AddOutput("output", ShaderUsage::UseUniform);
1414
shader.AddOutput("scales", ShaderUsage::UseUniform);
15-
shader.AdditionalImplementation() << R"ADDNL_FN(
16-
fn readInput(offset: u32) -> input_a_value_t
17-
{
18-
if (offset > uniforms.input_size) {
19-
return input_a_value_t(0);
20-
}
21-
return input_a[offset];
22-
}
23-
)ADDNL_FN";
2415
shader.MainFunctionBody() << R"MAIN_FN(
2516
var local_a : array<vec4<input_a_element_t>, 32>;
2617
var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
2718
for (var idx:u32=0;idx<32;idx+=1)
2819
{
29-
local_a[idx] = readInput(workgroup_idx*32 + idx);
20+
local_a[idx] = input_a[workgroup_idx*32 + idx];
3021
max_value = max(max_value, abs(local_a[idx]));
3122
}
3223
var scale = max(max_value.x, max_value.y);
@@ -279,8 +270,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
279270
Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
280271
quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)}})
281272
.AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1},
282-
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}})
283-
.AddUniformVariable({static_cast<uint32_t>(M * K / kVec4Components)});
273+
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}});
284274
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
285275

286276
constexpr uint32_t kTileSize = 64;
@@ -317,7 +307,7 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
317307
bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) &&
318308
context.AdapterInfo().backendType != wgpu::BackendType::Metal;
319309
return (accuracy_level == 4 && block_size % 32 == 0 &&
320-
batch_count == 1 && components_k == 4 && K % 64 == 0 && N % 16 == 0 &&
310+
batch_count == 1 && components_k == 4 && K % 128 == 0 && N % 16 == 0 &&
321311
!has_zero_points && use_dp4a);
322312
}
323313

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram
1616
public:
1717
DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {}
1818
Status GenerateShaderCode(ShaderHelper& sh) const override;
19-
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32});
2019
};
2120

2221
class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {

onnxruntime/test/contrib_ops/matmul_4bits_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ TEST(MatMulNBits, Float32_Accuracy4) {
389389
TestMatMulNBitsTyped<float, 100, 288, 16, 16, 4>();
390390
TestMatMulNBitsTyped<float, 100, 288, 1024, 16, 4>();
391391
TestMatMulNBitsTyped<float, 100, 288, 1024, 128, 4>();
392+
TestMatMulNBitsTyped<float, 100, 288, 192, 64, 4>();
392393
TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>();
393394
TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>();
394395
TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>();
@@ -458,6 +459,7 @@ TEST(MatMulNBits, Float16_Accuracy4) {
458459
TestMatMulNBitsTyped<MLFloat16, 100, 288, 16, 16, 4>();
459460
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1024, 16, 4>();
460461
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1024, 128, 4>();
462+
TestMatMulNBitsTyped<MLFloat16, 100, 288, 192, 64, 4>();
461463
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 32, 4>();
462464
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 128, 4>();
463465
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1234, 16, 4>();

0 commit comments

Comments
 (0)