Skip to content

Commit d6219b6

Browse files
authored
[webgpu] Fix GatherBlockQuantized on Intel ADL/TGL platforms (#26526)
### Description

The `GatherBlockQuantized` operation was using incorrect `data_indices` during execution on Intel Alder Lake (ADL) and Tiger Lake (TGL) platforms. This change sets the proper `data_indices`, resolving correctness issues encountered with the Phi-4-mini model on these architectures.

### Motivation and Context

See above.
1 parent 907ede2 commit d6219b6

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using onnxruntime::webgpu::ComputeContext;
1717
Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) const {
1818
const auto& x = shader.AddInput("input", ShaderUsage::UseElementTypeAlias);
1919
const auto& x_shape = shader.AddIndices("input_shape", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
20-
const auto& indices = shader.AddInput("indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseIndicesToOffset);
20+
const auto& indices = shader.AddInput("indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseIndicesToOffset | ShaderUsage::UseValueTypeAlias);
2121
const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
2222
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias);
2323

@@ -38,17 +38,23 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con
3838
shader.MainFunctionBody()
3939
<< "let indices_indices = " << output.IndicesGet("output_indices", "uniforms.gather_axis") << ";\n";
4040
}
41+
42+
shader.MainFunctionBody()
43+
<< "var index = " << indices.GetByIndices("indices_indices") << ";\n"
44+
<< "if (index < 0) { index += indices_value_t(" << x_shape.IndicesGet("uniforms.input_shape_shape", gather_axis_) << ");}\n"
45+
<< "var data_indices = input_shape_indices_t(0);\n";
46+
47+
for (int i = 0, j = 0; i < x_shape.Rank(); i++) {
48+
if (static_cast<int>(i) == gather_axis_) {
49+
shader.MainFunctionBody() << " " << x_shape.IndicesSet("data_indices", i, "u32(index)") << ";\n";
50+
j += indices.Rank();
51+
} else {
52+
shader.MainFunctionBody() << " " << x_shape.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n";
53+
j++;
54+
}
55+
}
56+
4157
shader.MainFunctionBody()
42-
<< "var data_indices = input_shape_indices_t(0);\n"
43-
<< "for (var i: u32 = 0; i < uniforms.gather_axis; i++) {\n"
44-
<< " let index = " << output.IndicesGet("output_indices", "i") << ";\n "
45-
<< x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n"
46-
<< "var index_from_indices = " << indices.GetByIndices("indices_indices") << ";\n"
47-
<< "if (index_from_indices < 0) { index_from_indices += " << x_shape_[gather_axis_] << ";}\n"
48-
<< x_shape.IndicesSet("data_indices", "uniforms.gather_axis", "u32(index_from_indices)") << ";\n"
49-
<< "for (var i = uniforms.gather_axis + 1; i < " << output_shape_.NumDimensions() << "; i++) {\n"
50-
<< " let index = " << output.IndicesGet("output_indices", "i + " + std::to_string(indices_rank_ - 1)) << ";\n "
51-
<< x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n"
5258
<< " let data_offset = " << x_shape.IndicesToOffset("data_indices") << ";\n";
5359

5460
if (is_4bit) {

0 commit comments

Comments
 (0)