
Commit 607d5e4

[WebGPU] Implement Split-K on Conv|MatMul (#26461)
### Description

This patch implements the `Split-K` optimization on `Conv|MatMul`. With `Split-K` we can re-arrange the computation into multiple workgroups when `K` is large, increasing parallelism on the platforms where `Split-K` is confirmed to be useful.

1. Support `Split-K` in `MakeMatMulPackedVec4Source()` to split a workgroup with a large `K` into smaller ones. In this patch we only support `Split-K` with `batch_size == 1` and `vec4` on `Conv|MatMul`.
2. Support `Split-K` in `MatMulWriteFnSource()` (add the partial result to the output with atomic built-in functions).
3. Implement `SplitKConfig` to decide whether `Split-K` should be used, together with all the related thresholds.
4. Implement `MatMulFillBiasOrZeroBeforeSplitKProgram` to initialize the output with `bias` or 0 when `Split-K` is used.

### Motivation and Context

In the current implementation, when `K` (`dim_inner`) is large, each invocation performs the whole reduction in one very large loop, which may not make full use of all EUs on a GPU. With `Split-K` we split that large reduction over `K` across multiple workgroups, each handling a smaller chunk (`kSplitK`, smaller than `K`), which greatly improves parallelism.

With this patch we measure about a 15% performance improvement on `efficientnet-lite-f16-demo` and a 9% improvement on `mobilenetv2-12-f16-demo` on Lunar Lake and Meteor Lake.
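To make the re-arrangement concrete, here is a minimal CPU-side sketch (illustration only, not code from this patch; all names are made up) of why splitting the `K` reduction preserves the result: each chunk of `kSplitK` elements is reduced independently (on the GPU, one workgroup per chunk along `dispatch_z`), and the partial sums are accumulated into an output that was pre-filled with the bias or 0.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

float FullReduction(const std::vector<float>& a, const std::vector<float>& b) {
  float acc = 0.0f;
  for (std::size_t k = 0; k < a.size(); ++k) acc += a[k] * b[k];
  return acc;
}

float SplitKReduction(const std::vector<float>& a, const std::vector<float>& b, std::size_t split_k) {
  // On the GPU the output slot is pre-filled with the bias (or 0) before the MatMul program runs.
  float output = 0.0f;
  const std::size_t num_splits = (a.size() + split_k - 1) / split_k;  // becomes dispatch_z
  for (std::size_t z = 0; z < num_splits; ++z) {  // each iteration models one workgroup along global_id.z
    const std::size_t k_start = z * split_k;
    const std::size_t k_end = std::min(a.size(), k_start + split_k);
    float partial = 0.0f;
    for (std::size_t k = k_start; k < k_end; ++k) partial += a[k] * b[k];
    output += partial;  // on the GPU this accumulation uses the atomic CAS loop in HandleMatMulWithSplitK()
  }
  return output;
}

int main() {
  std::vector<float> a(4096, 0.5f), b(4096, 2.0f);
  // Both paths compute the same dot product (4096 here).
  std::cout << FullReduction(a, b) << " == " << SplitKReduction(a, b, /*split_k=*/256) << "\n";
}
```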
1 parent 81a04ca commit 607d5e4

File tree

18 files changed: +633 −48 lines changed


onnxruntime/core/providers/webgpu/compute_context.cc

Lines changed: 4 additions & 0 deletions

```diff
@@ -20,5 +20,9 @@ const webgpu::BufferManager& ComputeContext::BufferManagerAccessor::Get(const Co
   return context.ep_.BufferManager();
 }
 
+const SplitKConfig& ComputeContext::GetSplitKConfig() {
+  return webgpu_context_.GetSplitKConfig();
+}
+
 }  // namespace webgpu
 }  // namespace onnxruntime
```

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 7 additions & 0 deletions

```diff
@@ -152,6 +152,13 @@ class ComputeContext final {
     return webgpu_context_.Run(*this, program);
   }
 
+  //
+  // Get Split-K configuration.
+  //
+  // `split_k_config_` won't be initialized until the first call to this method.
+  //
+  const SplitKConfig& GetSplitKConfig();
+
  private:
   WebGpuContext& webgpu_context_;
   OpKernelContext& kernel_context_;
```
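The comment above describes lazy initialization; the corresponding `WebGpuContext::GetSplitKConfig()` body is not part of this diff, so the following is only a hypothetical sketch of that pattern (member and type details are assumptions, not the actual implementation):

```cpp
#include <optional>

struct SplitKConfig {
  // Thresholds and device-specific flags would live here.
};

class WebGpuContextSketch {
 public:
  // Build the config only when something actually asks for it, so contexts that never run a
  // Split-K-eligible MatMul pay no initialization cost.
  const SplitKConfig& GetSplitKConfig() {
    if (!split_k_config_.has_value()) {
      split_k_config_.emplace();  // the real code would likely consult adapter/device info here
    }
    return *split_k_config_;
  }

 private:
  std::optional<SplitKConfig> split_k_config_;
};
```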

onnxruntime/core/providers/webgpu/math/gemm_utils.cc

Lines changed: 120 additions & 11 deletions

```diff
@@ -13,7 +13,7 @@ namespace webgpu {
 // which are used in the MatMulWriteFnSource function.
 namespace {
 
-void HanldeMaybeHaveBiasForGEMM(ShaderHelper& shader,
+void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader,
                                 const ShaderVariableHelper& output,
                                 bool has_bias,
                                 int c_components,
@@ -53,6 +53,70 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
             << output.SetByIndices("coords", "value") << "\n";
 }
 
+void HandleMatMulWithSplitK(
+    ShaderHelper& shader,
+    ProgramVariableDataType output_variable_type) {
+  shader.AdditionalImplementation() << "  let coords = vec3(u32(batch), u32(row), u32(colIn));\n";
+
+  // With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups,
+  // so we must add them with atomic built-in functions. Because currently WebGPU doesn't support
+  // atomic built-in functions on `f32` or `f16`, we implement the `atomicAdd` on `f32` and `f16`
+  // with `atomicLoad` and `atomicCompareExchangeWeak`:
+  // 1. Get `old_output_i32` from `output[offset]` with `atomicLoad`.
+  // 2. Convert `old_output_i32` into `f32` (`old_output_f32`) or `vec2h` (`old_output_vec2h`).
+  // 3. Add incoming `value` into `old_output_f32` or `old_output_vec2h`.
+  // 4. Convert the result of step 3 into `i32` values.
+  // 5. Try assigning the result of step 4 into `output[offset]` with `atomicCompareExchangeWeak`
+  //    and `old_output_i32`. The assignment will fail if at this time `output[offset]` is not
+  //    equal to `old_output_i32` (it is updated in another invocation). If the assignment fails
+  //    we have to go to step 1 and repeat all the above steps.
+  switch (output_variable_type) {
+    case ProgramVariableDataType::Float32x4: {
+      shader.AdditionalImplementation() << R"(
+  let offset0 = i2o_output(coords) * 4u;
+  for (var i = 0u; i < 4u; i++) {
+    let offset = offset0 + i;
+    while (true) {
+      let old_output_i32 = atomicLoad(&output[offset]);
+      let old_output_f32 = bitcast<f32>(old_output_i32);
+      let new_output_f32 = old_output_f32 + value[i];
+      let new_output_i32 = bitcast<i32>(new_output_f32);
+      let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32);
+      if (output_compare_exchange.old_value == old_output_i32) {
+        break;
+      }
+    }
+  }
+)";
+      break;
+    }
+    case ProgramVariableDataType::Float16x4: {
+      shader.AdditionalImplementation() << R"(
+  let offset0 = i2o_output(coords) * 2u;
+  var vec2h_values : array<vec2h, 2>;
+  vec2h_values[0] = value.xy;
+  vec2h_values[1] = value.zw;
+  for (var i = 0u; i < 2u; i++) {
+    let offset = offset0 + i;
+    while (true) {
+      let old_output_i32 = atomicLoad(&output[offset]);
+      let old_output_vec2h = bitcast<vec2h>(old_output_i32);
+      let new_output_vec2h = old_output_vec2h + vec2h_values[i];
+      let new_output_i32 = bitcast<i32>(new_output_vec2h);
+      let output_compare_exchange = atomicCompareExchangeWeak(&output[offset], old_output_i32, new_output_i32);
+      if (output_compare_exchange.old_value == old_output_i32) {
+        break;
+      }
+    }
+  }
+)";
+      break;
+    }
+    default:
+      break;
+  }
+}
+
 }  // namespace
 
 void MatMulReadFnSource(ShaderHelper& shader,
@@ -125,7 +189,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
                          int output_components,
                          bool c_is_scalar,
                          std::string activation_snippet,
-                         bool is_channels_last) {
+                         bool is_channels_last,
+                         bool use_split_k,
+                         ProgramVariableDataType output_variable_type) {
   shader.AdditionalImplementation()
       << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n";
 
@@ -134,8 +200,17 @@ void MatMulWriteFnSource(ShaderHelper& shader,
   shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n"
                                     << "  var value = valueIn; \n";
 
-  if (is_gemm) {
-    HanldeMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar);
+  if (use_split_k) {
+    // Set output when MatMul is performed with Split-K.
+    // When Split-K is used in MatMul, the bias will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram`
+    // instead of here, so `has_bias` and `is_channels_last` is not used for Split-K. Note that we
+    // still need to handle `has_bias` (and `is_channels_last` in the future) in
+    // `MatMulFillBiasOrZeroBeforeSplitKProgram`.
+    ORT_ENFORCE(!has_bias, "Bias is not supported in MatMulProgram when Split-K is enabled.");
+    ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled.");
+    HandleMatMulWithSplitK(shader, output_variable_type);
+  } else if (is_gemm) {
+    HandleMaybeHaveBiasForGEMM(shader, output, has_bias, c_components, output_components, c_is_scalar);
   } else {
     HandleMaybeBiasForMatMul(shader, output, has_bias, activation_snippet, is_channels_last);
   }
@@ -159,9 +234,6 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
                                   uint32_t tile_inner,
                                   bool split_k,
                                   uint32_t split_dim_inner) {
-  ORT_UNUSED_PARAMETER(split_k);
-  ORT_UNUSED_PARAMETER(split_dim_inner);
-
   const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type);
 
   std::string write_data_to_sub_a_vec4_snippet =
@@ -208,14 +280,51 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
       << "  let tileCol = i32(local_id.x);\n"
       << "  let globalRow = i32(global_id.y) * rowPerThread;\n"
      << "  let globalCol = i32(global_id.x);\n"
-      << "  let batch = i32(global_id.z);\n"
-      << (nullptr != batch_dims ? "  let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
       << "  let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
       << "  let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
-      << "  let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
-      << "  var kStart = 0;\n"
       << "  var acc: array<vec4<" << data_type << ">, rowPerThread>;\n";
 
+  if (split_k) {
+    // With Split-K, the original "workgroup" (with dispatch_z == 1 in API side) is split into
+    // multiple ones, and in the current workgroup we only compute `kSplitK` elements starting from
+    // `kSplitK * i32(global_id.z)`.
+    //
+    // For example: considering computing Y = (X * W + B) in one workgroup.
+    // Let kSplitK = 2, B = [d1, d2]
+    // Let X = [[a1 a1 b1 b1 c1 c1]  = [ A1 B1 C1 ],   W = [[a2 a2]  = [ A2
+    //          [a1 a1 b1 b1 c1 c1]]                        [a2 a2]     B2
+    //                                                      [b2 b2]     C2 ]
+    //                                                      [b2 b2]
+    //                                                      [c2 c2]
+    //                                                      [c2 c2]]
+    //
+    // With Split-K:
+    // 1. Initialize output Y with B in `MatMulFillBiasOrZeroBeforeSplitKProgram`: Y = [[d1, d2]
+    //                                                                                  [d1, d2]]
+    // 2. Split the original 1 workgroup into 3 workgroups (now `dispatch_z = 3` in API side)
+    //    Workgroup1: compute (A1 * A2)    Workgroup2: compute (B1 * B2)
+    //    Workgroup3: compute (C1 * C2)
+    //    In each workgroup:
+    //    - `num_tiles` is computed with `kSplitK`, and `kStart` is computed with `global_id.z`
+    //    - When the computation in each workgroup is completed, add the result to Y with several
+    //      atomic built-in functions in `HandleMatMulWithSplitK()`.
+    shader.MainFunctionBody()
+        << "const kSplitK = " << split_dim_inner << ";\n"
+        << "  let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
+        << "  var kStart = kSplitK * i32(global_id.z);\n"
+
+        // When Split-K is used, `batch` should always be 0 and `global_id.z` is used to indicate
+        // the index of split-k instead of batch.
+        << "  let batch = 0;\n"
+        << "  let batchIndices = 0u;\n";
+  } else {
+    shader.MainFunctionBody()
+        << "  let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
+        << "  var kStart = 0;\n"
+        << "  let batch = i32(global_id.z);\n"
+        << (nullptr != batch_dims ? "  let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
+  }
+
   // Loop over shared dimension.
   shader.MainFunctionBody() << "  let tileRowB = localRow * " << row_per_thread_b << ";\n";
```
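For readers more familiar with C++ than WGSL, the shader's emulated floating-point `atomicAdd` above is the standard compare-and-swap retry loop. A host-side analogue (illustration only, not part of the patch; assumes C++20 for `std::bit_cast`) looks like this:

```cpp
#include <atomic>
#include <bit>
#include <cstdint>

// Add `value` to the float stored in `word` without a native float atomicAdd: reinterpret the
// 32-bit word as float, add, reinterpret back, and retry until no other thread changed the word
// in between. compare_exchange_weak reloads the current contents into `old_bits` on failure,
// which plays the role of the shader's repeated atomicLoad.
void AtomicAddF32(std::atomic<std::uint32_t>& word, float value) {
  std::uint32_t old_bits = word.load(std::memory_order_relaxed);
  while (true) {
    const float old_value = std::bit_cast<float>(old_bits);
    const std::uint32_t new_bits = std::bit_cast<std::uint32_t>(old_value + value);
    if (word.compare_exchange_weak(old_bits, new_bits, std::memory_order_relaxed)) {
      break;
    }
  }
}
```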

onnxruntime/core/providers/webgpu/math/gemm_utils.h

Lines changed: 3 additions & 1 deletion

```diff
@@ -24,7 +24,9 @@ void MatMulWriteFnSource(ShaderHelper& shader,
                          int output_components,
                          bool c_is_scalar,
                          std::string activation_snippet = "",
-                         bool is_channels_last = false);
+                         bool is_channels_last = false,
+                         bool use_split_k = false,
+                         ProgramVariableDataType output_variable_type = ProgramVariableDataType::Float32x4);
 
 // The two following functions are used to generate shader code for vec4 and scalar.
 // It is used in GEMM, Matmul, and Conv.
```

onnxruntime/core/providers/webgpu/math/matmul.cc

Lines changed: 81 additions & 15 deletions

```diff
@@ -161,14 +161,14 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
     const auto* bias = context.Input(2);
     inputs.push_back(bias);
   }
-  auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false);
 
-  return context.RunProgram(program);
+  return ComputeMatMul(&context, Activation(), inputs, output_tensor, false);
 }
 
-MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
-                                  const TensorShape& input_a_reshape,
-                                  const TensorShape& input_b_reshape) {
+Status ComputeMatMul(ComputeContext* context,
+                     const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output_tensor, bool is_channels_last,
+                     const TensorShape& input_a_reshape,
+                     const TensorShape& input_b_reshape) {
   const auto* a = inputs[0];
   const auto* b = inputs[1];
   bool has_bias = inputs.size() > 2;
@@ -226,31 +226,97 @@ Status ComputeMatMul(ComputeContext* context,
                                                (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
   const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
                                                (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
-  const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
-                                               (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
+  uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
+                                         (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
 
   const int components = is_vec4 ? 4 : 1;
   const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
   const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
   const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});
 
-  MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last};
-  program
-      .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last)
+  ProgramOutput output(output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components);
+  const Tensor* bias = has_bias ? inputs[2] : nullptr;
+  bool use_bias_in_matmul = has_bias;
+  uint32_t split_dim_inner = 1;
+
+  const SplitKConfig& split_k_config = context->GetSplitKConfig();
+  const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
+  if (need_split_k) {
+    ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
+    ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format.");
+    ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");
+
+    // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
+    const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp);
+    ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));
+
+    // `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
+    // `bias` again in `MatMulProgram`.
+    use_bias_in_matmul = false;
+
+    // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
+    // number of splits along `dim_inner`.
+    // TODO: avoid using `global_id.xxx` or `workgroup_id.xxx` in `MatMulProgram` when we normalize
+    // the dispatch size with `ProgramManager::NormalizeDispatchGroupSize()` for `MatMulProgram`.
+    split_dim_inner = split_k_config.GetSplitDimInner();
+    dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
+
+    // The output should be declared in atomic types in `MatMulProgram` for the use of atomic
+    // built-in functions.
+    output.is_atomic = true;
+  }
+
+  MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner};
+  matmul_program
+      .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
       .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
                   {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
-      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
       .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
       .AddIndices(outer_dims)
       .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
-      .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z);
+      .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z)
+      .AddOutput(std::move(output));
 
-  if (has_bias) {
+  if (use_bias_in_matmul) {
     auto bias_components = is_channels_last ? components : 1;
-    const auto* bias = inputs[2];
     TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
-    program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
+    matmul_program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
+  }
+
+  return context->RunProgram(matmul_program);
+}
+
+MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
+    const Tensor* bias,
+    Tensor* output,
+    const TensorShape& output_shape_vec4) {
+  const bool has_bias = bias != nullptr;
+
+  // Currently we only support bias in vec4 and channels-last format for Split-K MatMul.
+  constexpr uint32_t bias_components = 4;
+  MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias);
+
+  const uint32_t dim_a_outer = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]);
+  const uint32_t dim_b_outer_vec4 = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]);
+
+  // Fill one value (currently only vec4) per invocation. Now we use the default workgroup size (64)
+  // for this program.
+  const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4;
+  const uint32_t dispatch_x = (total_outputs_vec4 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
+
+  // To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar
+  // instead of vec4, while using `output_shape_vec4` directly as the output shape.
+  const uint32_t dim_b_outer = narrow<uint32_t>(dim_b_outer_vec4 * bias_components);
+  program.CacheHint(has_bias)
+      .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast<int32_t>(bias_components)})
+      .AddUniformVariables({{dim_a_outer}, {dim_b_outer}})
+      .SetDispatchGroupSize(dispatch_x);
+
+  if (has_bias) {
+    const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
+    program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int32_t>(bias_components)});
   }
 
   return program;
 }
```
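As a sanity check on the two ceil-divisions introduced above (the Split-K `dispatch_z` and the fill program's `dispatch_x`), here is a small standalone sketch with hypothetical sizes; the concrete `split_dim_inner` actually chosen by `SplitKConfig` is not shown in this diff.

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const std::uint32_t dim_inner = 1536;       // K, the shared dimension (example value)
  const std::uint32_t split_dim_inner = 256;  // kSplitK, assumed value from SplitKConfig
  // Number of K-chunks, one workgroup per chunk along global_id.z.
  const std::uint32_t dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;  // 6

  const std::uint32_t dim_a_outer = 64;       // output rows (example value)
  const std::uint32_t dim_b_outer_vec4 = 96;  // output columns / 4 (example value)
  const std::uint32_t workgroup_size = 64;    // default WORKGROUP_SIZE
  // One vec4 output element is filled per invocation of the fill-bias-or-zero program.
  const std::uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4;
  const std::uint32_t dispatch_x = (total_outputs_vec4 + workgroup_size - 1) / workgroup_size;  // 96

  std::cout << "MatMul dispatch_z = " << dispatch_z << ", fill dispatch_x = " << dispatch_x << "\n";
}
```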

onnxruntime/core/providers/webgpu/math/matmul.h

Lines changed: 8 additions & 3 deletions

```diff
@@ -14,9 +14,14 @@
 namespace onnxruntime {
 namespace webgpu {
 
-MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output, bool is_channels_last,
-                                  const TensorShape& input_a_reshape = TensorShape(),
-                                  const TensorShape& input_b_reshape = TensorShape());
+Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output, bool is_channels_last,
+                     const TensorShape& input_a_reshape = TensorShape(),
+                     const TensorShape& input_b_reshape = TensorShape());
+
+MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
+    const Tensor* bias,
+    Tensor* output,
+    const TensorShape& output_shape_vec4);
 
 class MatMul final : public WebGpuKernel {
  public:
```
