@@ -17,7 +17,7 @@ using onnxruntime::webgpu::ComputeContext;
1717Status GatherBlockQuantizedProgram::GenerateShaderCode (ShaderHelper& shader) const {
1818 const auto & x = shader.AddInput (" input" , ShaderUsage::UseElementTypeAlias);
1919 const auto & x_shape = shader.AddIndices (" input_shape" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
20- const auto & indices = shader.AddInput (" indices" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseIndicesToOffset);
20+ const auto & indices = shader.AddInput (" indices" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseIndicesToOffset | ShaderUsage::UseValueTypeAlias );
2121 const auto & scales = shader.AddInput (" scales" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
2222 const auto & output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias);
2323
@@ -38,17 +38,23 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con
3838 shader.MainFunctionBody ()
3939 << " let indices_indices = " << output.IndicesGet (" output_indices" , " uniforms.gather_axis" ) << " ;\n " ;
4040 }
41+
42+ shader.MainFunctionBody ()
43+ << " var index = " << indices.GetByIndices (" indices_indices" ) << " ;\n "
44+ << " if (index < 0) { index += indices_value_t(" << x_shape.IndicesGet (" uniforms.input_shape_shape" , gather_axis_) << " );}\n "
45+ << " var data_indices = input_shape_indices_t(0);\n " ;
46+
47+ for (int i = 0 , j = 0 ; i < x_shape.Rank (); i++) {
48+ if (static_cast <int >(i) == gather_axis_) {
49+ shader.MainFunctionBody () << " " << x_shape.IndicesSet (" data_indices" , i, " u32(index)" ) << " ;\n " ;
50+ j += indices.Rank ();
51+ } else {
52+ shader.MainFunctionBody () << " " << x_shape.IndicesSet (" data_indices" , i, output.IndicesGet (" output_indices" , j)) << " ;\n " ;
53+ j++;
54+ }
55+ }
56+
4157 shader.MainFunctionBody ()
42- << " var data_indices = input_shape_indices_t(0);\n "
43- << " for (var i: u32 = 0; i < uniforms.gather_axis; i++) {\n "
44- << " let index = " << output.IndicesGet (" output_indices" , " i" ) << " ;\n "
45- << x_shape.IndicesSet (" data_indices" , " i" , " index" ) << " ;\n };\n "
46- << " var index_from_indices = " << indices.GetByIndices (" indices_indices" ) << " ;\n "
47- << " if (index_from_indices < 0) { index_from_indices += " << x_shape_[gather_axis_] << " ;}\n "
48- << x_shape.IndicesSet (" data_indices" , " uniforms.gather_axis" , " u32(index_from_indices)" ) << " ;\n "
49- << " for (var i = uniforms.gather_axis + 1; i < " << output_shape_.NumDimensions () << " ; i++) {\n "
50- << " let index = " << output.IndicesGet (" output_indices" , " i + " + std::to_string (indices_rank_ - 1 )) << " ;\n "
51- << x_shape.IndicesSet (" data_indices" , " i" , " index" ) << " ;\n };\n "
5258 << " let data_offset = " << x_shape.IndicesToOffset (" data_indices" ) << " ;\n " ;
5359
5460 if (is_4bit) {
0 commit comments