Skip to content

Commit 975a4a3

Browse files
SS-JIAssjia
andauthored
[ET-VK][ez] Support grouped convolutions (#13599)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #13597 * #13596 * #13595 * #13594 * #13593 * #13600 * __->__ #13599 * #13598 Title says it all! Differential Revision: [D80741734](https://our.internmc.facebook.com/intern/diff/D80741734/) --------- Co-authored-by: ssjia <[email protected]>
1 parent 6c1f9fa commit 975a4a3

File tree

6 files changed

+43
-5
lines changed

6 files changed

+43
-5
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3030

3131
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3232

33+
${layout_declare_spec_const(C, "int", "ngroups", "1")}
34+
3335
/*
3436
* Computes a 2D convolution. Each shader invocation calculates the output at
3537
* a single output location.
@@ -74,7 +76,18 @@ void main() {
7476
// Perform the convolution by iterating over the overlay region.
7577
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
7678
const int ic4 = in_group_size / 4;
77-
for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) {
79+
80+
int z_start = 0;
81+
int z_end = ic4;
82+
if (ngroups > 1) {
83+
const int group_size = (out_limits.z) / ngroups;
84+
const int group_idx = pos.z / group_size;
85+
86+
z_start = ic4 * group_idx;
87+
z_end = z_start + ic4;
88+
}
89+
90+
for (int z4 = z_start; z4 < z_end; ++z4, kstart.x += kernel_size.x * 4) {
7891
for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) {
7992
for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) {
8093
const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0);

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3030

3131
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3232

33+
${layout_declare_spec_const(C, "int", "ngroups", "1")}
34+
3335
/*
3436
* Computes a depthwise convolution. Each shader invocation calculates the
3537
* output at a single output location.

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ layout(push_constant) uniform restrict Block {
3838

3939
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4040

41+
${layout_declare_spec_const(C, "int", "ngroups", "1")}
42+
4143
#extension GL_EXT_control_flow_attributes : require
4244

4345
/*

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ layout(push_constant) uniform restrict Block {
4040

4141
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4242

43+
${layout_declare_spec_const(C, "int", "ngroups", "1")}
44+
4345
#extension GL_EXT_control_flow_attributes : require
4446

4547
/*

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,6 @@ Conv2dMethod get_conv2d_method(
280280
if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
281281
return Conv2dMethod::Depthwise;
282282
}
283-
if (groups > 1) {
284-
VK_THROW("aten.convolution.default: groups > 1 is not supported yet!");
285-
}
286283
if (transposed) {
287284
return Conv2dMethod::Transposed;
288285
}
@@ -601,7 +598,7 @@ void add_conv2d_node(
601598
// Push Constants
602599
push_constants,
603600
// Specialization Constants
604-
{},
601+
{utils::safe_downcast<int32_t>(groups_val)},
605602
// Resize Args
606603
{weight_data, stride, padding, dilation, transposed, output_padding},
607604
// Resizing Logic

backends/vulkan/test/op_tests/cases.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,28 @@ def get_conv_inputs():
297297
)
298298

299299
test_cases = [
300+
Test(
301+
self=(1, 64, 256, 256),
302+
weight=(64, 32, 3, 3),
303+
bias=None,
304+
stride=[1, 1],
305+
padding=[1, 1],
306+
dilation=[1, 1],
307+
transposed=False,
308+
output_padding=[0, 0],
309+
groups=2,
310+
),
311+
Test(
312+
self=(1, 16, 3, 3),
313+
weight=(16, 8, 3, 3),
314+
bias=None,
315+
stride=[1, 1],
316+
padding=[1, 1],
317+
dilation=[1, 1],
318+
transposed=False,
319+
output_padding=[0, 0],
320+
groups=2,
321+
),
300322
Test(
301323
self=(1, 6, 40, 50),
302324
weight=(8, 6, 3, 3),

0 commit comments

Comments
 (0)