Skip to content

Commit 6f9a236

Browse files
committed
[ET-VK] Enable buffer implementation of aten.linear
## Changes As titled. Extend the existing buffer implementation of `matmul` to also support the linear operator. Differential Revision: [D65277712](https://our.internmc.facebook.com/intern/diff/D65277712/) [ghstack-poisoned]
1 parent 7513dfa commit 6f9a236

File tree

3 files changed

+59
-16
lines changed

3 files changed

+59
-16
lines changed

backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,23 @@ ${define_required_extensions(DTYPE)}
1616

1717
layout(std430) buffer;
1818

19-
${layout_declare_tensor(0, "w", "t_out", DTYPE, "buffer")}
20-
${layout_declare_tensor(1, "r", "t_mat1", DTYPE, "buffer")}
21-
${layout_declare_tensor(2, "r", "t_mat2", DTYPE, "buffer")}
22-
${layout_declare_ubo(3, "ivec4", "out_sizes")}
23-
${layout_declare_ubo(4, "ivec4", "out_strides")}
24-
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
25-
${layout_declare_ubo(6, "ivec4", "mat1_strides")}
26-
${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
27-
${layout_declare_ubo(8, "ivec4", "mat2_strides")}
28-
${layout_declare_ubo(9, "int", "out_numel")}
19+
${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
20+
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer")}
21+
${layout_declare_tensor(B, "r", "t_mat2", DTYPE, "buffer")}
22+
${layout_declare_ubo(B, "ivec4", "out_sizes")}
23+
${layout_declare_ubo(B, "ivec4", "out_strides")}
24+
${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
25+
${layout_declare_ubo(B, "ivec4", "mat1_strides")}
26+
${layout_declare_ubo(B, "ivec4", "mat2_sizes")}
27+
${layout_declare_ubo(B, "ivec4", "mat2_strides")}
28+
${layout_declare_ubo(B, "int", "out_numel")}
2929

3030
#include "indexing_utils.h"
3131

3232
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3333

34+
${layout_declare_spec_const(C, "int", "mat2_is_transposed", "0")}
35+
3436
void main() {
3537
const ivec4 out_bufix = ivec4(
3638
gl_GlobalInvocationID.x,
@@ -44,15 +46,28 @@ void main() {
4446

4547
int mat1_bufi = tidx_to_bufi(
4648
ivec4(0, out_bufix.y, out_bufix.z, out_bufix.w), mat1_strides);
47-
int mat2_bufi = tidx_to_bufi(
48-
ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides);
49+
int mat2_bufi;
50+
if (mat2_is_transposed > 0) {
51+
mat2_bufi = tidx_to_bufi(
52+
ivec4(0, out_bufix.x, 0, 0), mat2_strides);
53+
} else {
54+
mat2_bufi = tidx_to_bufi(
55+
ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides);
56+
}
57+
58+
int mat2_stride;
59+
if (mat2_is_transposed > 0) {
60+
mat2_stride = mat2_strides.x;
61+
} else {
62+
mat2_stride = mat2_strides.y;
63+
}
4964

5065
T sum = T(0.0);
5166
for (int i = 0; i < mat1_sizes.x; ++i) {
5267
sum += t_mat1[mat1_bufi] * t_mat2[mat2_bufi];
5368

5469
mat1_bufi += mat1_strides.x;
55-
mat2_bufi += mat2_strides.y;
70+
mat2_bufi += mat2_stride;
5671
}
5772

5873
const int out_bufi = tidx_to_bufi(out_bufix, out_strides);

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ void add_matmul_naive_buffer_node(
7777
graph.size_at<uint32_t>(-2, out),
7878
graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};
7979

80+
int mat2_is_transposed_val = 0;
81+
if (mat2_is_transposed != kDummyValueRef &&
82+
graph.get_bool(mat2_is_transposed)) {
83+
mat2_is_transposed_val = 1;
84+
}
85+
8086
graph.execute_nodes().emplace_back(new DispatchNode(
8187
graph,
8288
VK_KERNEL_FROM_STR(kernel_name),
@@ -96,7 +102,7 @@ void add_matmul_naive_buffer_node(
96102
graph.numel_ubo(out),
97103
},
98104
// Specialization Constants
99-
{},
105+
{mat2_is_transposed_val},
100106
// Resizing Logic
101107
resize_matmul_node,
102108
{mat2_is_transposed}));

backends/vulkan/test/op_tests/cases.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,7 @@ def get_addmm_inputs():
126126
]
127127

128128

129-
@register_test_suite("aten.linear.default")
130-
def get_linear_inputs():
129+
def get_linear_texture_inputs():
131130
MKN_list = common_MKN_list
132131

133132
inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list]
@@ -142,9 +141,32 @@ def get_linear_inputs():
142141
"utils::kWidthPacked",
143142
"utils::kChannelsPacked",
144143
]
144+
test_suite.test_name_suffix = "texture"
145+
return test_suite
146+
147+
148+
def get_linear_buffer_inputs():
149+
MKN_list = common_MKN_list
150+
151+
inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list]
152+
inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list]
153+
154+
test_suite = VkTestSuite(inputs_list)
155+
test_suite.dtypes = ["at::kFloat"]
156+
test_suite.layouts = [
157+
"utils::kWidthPacked",
158+
"utils::kChannelsPacked",
159+
]
160+
test_suite.storage_types = ["utils::kBuffer"]
161+
test_suite.test_name_suffix = "buffer"
145162
return test_suite
146163

147164

165+
@register_test_suite("aten.linear.default")
166+
def get_linear_test_suites():
167+
return [get_linear_texture_inputs(), get_linear_buffer_inputs()]
168+
169+
148170
@register_test_suite("aten._weight_int8pack_mm.default")
149171
def get_weight_int8pack_mm_inputs():
150172
MKN_list = common_MKN_list

0 commit comments

Comments
 (0)