Skip to content

Commit 203e2b2

Browse files
authored
[WebNN] Support GatherBlockQuantized op (#26020)
This op is used in ORT GenAI and can be decomposed into DequantizeLinear + Gather for WebNN.
1 parent bd28856 commit 203e2b2

File tree

5 files changed

+172
-0
lines changed

5 files changed

+172
-0
lines changed

js/web/docs/webnn-operators.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s
4242
| Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape | |
4343
| Floor | ai.onnx(7-12, 13+) | floor | |
4444
| Gather | ai.onnx(7-10, 11-12, 13+) | gather | |
45+
| GatherBlockQuantized | com.microsoft(1+) | dequantizeLinear, gather | |
4546
| GatherElements | ai.onnx(11-12, 13+) | gatherElements | |
4647
| GatherND | ai.onnx(11, 12, 13+) | gatherND | Only supports 'batch_dims' == 0 |
4748
| Gelu | ai.onnx(20+) | gelu | |
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Copyright (c) Intel Corporation. All rights reserved.
3+
// Licensed under the MIT License.
4+
5+
#include "core/providers/common.h"
6+
#include "core/providers/shared/utils/utils.h"
7+
#include "core/providers/webnn/builders/helper.h"
8+
#include "core/providers/webnn/builders/model_builder.h"
9+
#include "core/providers/webnn/builders/op_builder_factory.h"
10+
11+
#include "base_op_builder.h"
12+
13+
namespace onnxruntime {
14+
namespace webnn {
15+
16+
class GatherBlockQuantizedOpBuilder : public BaseOpBuilder {
17+
// Add operator related.
18+
private:
19+
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
20+
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
21+
22+
// Operator support related.
23+
private:
24+
bool IsOpSupportedImpl(const GraphViewer&, const Node& node,
25+
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
26+
bool HasSupportedInputsImpl(const GraphViewer&, const Node& node,
27+
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
28+
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
29+
const logging::Logger& logger) const override;
30+
};
31+
32+
// WebNN doesn't provide a dedicated op for GatherBlockQuantizedOpBuilder, it can be simply
33+
// decomposed by DequantizeLinear + Gather.
34+
Status GatherBlockQuantizedOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
35+
const Node& node,
36+
const logging::Logger& logger) const {
37+
const auto& input_defs = node.InputDefs();
38+
std::vector<int64_t> input_shape;
39+
std::vector<int64_t> scales_shape;
40+
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
41+
ORT_RETURN_IF_NOT(GetShape(*input_defs[2], scales_shape, logger), "Cannot get scales shape");
42+
const auto input_rank = input_shape.size();
43+
44+
int32_t input_type = 0;
45+
ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input data type");
46+
47+
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
48+
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
49+
emscripten::val scales = model_builder.GetOperand(input_defs[2]->Name());
50+
emscripten::val common_options = emscripten::val::object();
51+
52+
NodeAttrHelper helper(node);
53+
const int32_t bits = helper.Get("bits", 4);
54+
const uint32_t gather_axis = SafeInt<uint32_t>(HandleNegativeAxis(helper.Get("gather_axis", 0), input_rank));
55+
56+
// GatherBlockQuantized only supports block-wise quantization, the input and scales should have the same rank.
57+
// So we don't need to reshape scales for broadcasting.
58+
emscripten::val zero_points = emscripten::val::undefined();
59+
if (TensorExists(input_defs, 3)) { // zero_points
60+
zero_points = model_builder.GetOperand(input_defs[3]->Name());
61+
} else {
62+
const uint8_t default_zero_point = bits == 4 ? 0 : 128;
63+
// Create a constant for zero_points, which has the same shape as scales and same type as input.
64+
zero_points = model_builder.CreateOrGetConstant<uint8_t>(input_type,
65+
default_zero_point,
66+
GetNarrowedIntFromInt64<uint32_t>(scales_shape));
67+
}
68+
69+
// dequantized_input = DequantizeLinear(input, scales, zero_points)
70+
common_options.set("label", node.Name() + "_dequantize_input");
71+
emscripten::val dequantized_input = model_builder.GetBuilder().call<emscripten::val>("dequantizeLinear",
72+
input,
73+
scales,
74+
zero_points,
75+
common_options);
76+
77+
// output = Gather(dequantized_input, indices, axis=gather_axis)
78+
common_options.set("label", node.Name() + "_gather");
79+
common_options.set("axis", gather_axis);
80+
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("gather",
81+
dequantized_input,
82+
indices,
83+
common_options);
84+
85+
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
86+
return Status::OK();
87+
}
88+
89+
// Operator support related.
90+
91+
bool GatherBlockQuantizedOpBuilder::IsOpSupportedImpl(const GraphViewer&,
92+
const Node& node,
93+
const WebnnDeviceType /* device_type */,
94+
const logging::Logger& logger) const {
95+
NodeAttrHelper helper(node);
96+
const int32_t bits = helper.Get("bits", 4);
97+
const int32_t block_size = helper.Get("block_size", 128);
98+
99+
if (bits != 4 && bits != 8) {
100+
LOGS(logger, VERBOSE) << "GatherBlockQuantized only supports bits==4 or 8.";
101+
return false;
102+
}
103+
104+
if (block_size < 16 || ((block_size - 1) & block_size) != 0) {
105+
LOGS(logger, VERBOSE) << "GatherBlockQuantized: 'block_size' must be a power of 2 and not less than 16.";
106+
return false;
107+
}
108+
109+
return true;
110+
}
111+
112+
bool GatherBlockQuantizedOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
113+
const emscripten::val& wnn_limits,
114+
const logging::Logger& logger) const {
115+
const auto& input_defs = node.InputDefs();
116+
std::vector<int64_t> input_shape;
117+
std::vector<int64_t> scales_shape;
118+
if (!GetShape(*input_defs[0], input_shape, logger) ||
119+
!GetShape(*input_defs[2], scales_shape, logger)) {
120+
return false;
121+
}
122+
123+
if (input_shape.size() != scales_shape.size()) {
124+
LOGS(logger, VERBOSE) << "GatherBlockQuantized: input and scales must have the same rank.";
125+
return false;
126+
}
127+
128+
const std::string_view op_type = node.OpType();
129+
int32_t input_type = 0;
130+
int32_t scales_type = 0;
131+
if (!GetType(*input_defs[0], input_type, logger) ||
132+
!GetType(*input_defs[2], scales_type, logger)) {
133+
return false;
134+
}
135+
136+
// Only need to check the input data type of ops that consume the inputs of GatherBlockQuantized.
137+
// WebNN dequantizeLinear's input should be same as input. WebNN gather's input should be same as scales input.
138+
return IsDataTypeSupportedByWebNNOp(op_type, "dequantizeLinear", input_type, wnn_limits, "input", "data", logger) &&
139+
IsDataTypeSupportedByWebNNOp(op_type, "gather", scales_type, wnn_limits, "input", "scales", logger);
140+
141+
return true;
142+
}
143+
144+
bool GatherBlockQuantizedOpBuilder::HasSupportedOutputsImpl(const Node& node,
145+
const emscripten::val& wnn_limits,
146+
const logging::Logger& logger) const {
147+
const auto& output_defs = node.OutputDefs();
148+
const std::string_view op_type = node.OpType();
149+
int32_t output_type;
150+
if (!GetType(*output_defs[0], output_type, logger)) {
151+
return false;
152+
}
153+
154+
// Only need to check the output data type of ops that produce the output of GatherBlockQuantized.
155+
// WebNN gather's output should be same as GatherBlockQuantized's output.
156+
return IsDataTypeSupportedByWebNNOp(op_type, "gather", output_type, wnn_limits, "output", "output", logger);
157+
}
158+
159+
void CreateGatherBlockQuantizedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
160+
op_registrations.builders.push_back(std::make_unique<GatherBlockQuantizedOpBuilder>());
161+
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
162+
}
163+
164+
} // namespace webnn
165+
} // namespace onnxruntime

onnxruntime/core/providers/webnn/builders/map_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_ma
5050
{"DynamicQuantizeLinear",
5151
{"Cast", "Clip", "Div", "Max", "Min", "QuantizeLinear", "ReduceMax", "ReduceMin", "Reshape", "Round", "Sub"}},
5252
{"Einsum", {"MatMul", "Mul", "ReduceSum", "Reshape", "Transpose", "Trilu"}},
53+
{"GatherBlockQuantized", {"DequantizeLinear", "Gather"}},
5354
{"GroupQueryAttention",
5455
{"Add", "Cast", "Concat", "CumSum", "Div", "Expand", "Less", "MatMul", "Reshape", "ScatterND",
5556
"Softmax", "Transpose", "Where"}},

onnxruntime/core/providers/webnn/builders/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
108108
CreateGatherOpBuilder("Gather", op_registrations);
109109
}
110110

111+
{ // GatherBlockQuantized
112+
CreateGatherBlockQuantizedOpBuilder("GatherBlockQuantized", op_registrations);
113+
}
114+
111115
{ // GatherElements
112116
CreateGatherElementsOpBuilder("GatherElements", op_registrations);
113117
}

onnxruntime/core/providers/webnn/builders/op_builder_factory.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& o
3333
void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3434
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3535
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
36+
void CreateGatherBlockQuantizedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3637
void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3738
void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
3839
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

0 commit comments

Comments
 (0)