Skip to content

Commit f2b0595

Browse files
copyrightly authored and facebook-github-bot committed
refactor mm and linear implementation (#4011)
Summary: Pull Request resolved: #4011. Our existing `mm` uses a single generalized shader for both 2D and 3D input. nathanaelsee found that edge graphs smartly convert some 3D inputs into 2D, i.e. the needed op becomes `mm` instead of `bmm`. To optimize performance, we split the current implementation into separate 2D and 3D paths using template variants. Reviewed By: nathanaelsee. Differential Revision: D58629158. fbshipit-source-id: 4300cf086424618aecf0637da59bba46db05e42f
1 parent 524dd49 commit f2b0595

File tree

7 files changed

+118
-12
lines changed

7 files changed

+118
-12
lines changed

backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
$if MAT2_IS_TRANSPOSED:
1414
#define MAT2_IS_TRANSPOSED
1515

16+
$if BATCH_MODE:
17+
#define BATCH_MODE
18+
1619
#include "indexing_utils.h"
1720
#include "matmul.h"
1821

@@ -52,12 +55,20 @@ void main() {
5255
return;
5356
}
5457

55-
FloatMatrix results = matmul_partial_4x4(
58+
$if BATCH_MODE:
59+
FloatMatrix_3d results = matmul_partial_4x4x4(
5660
im_mat1,
5761
im_mat2,
5862
pos,
5963
out_sizes[2],
6064
in_limits[0]);
65+
$else:
66+
FloatMatrix_2d results = matmul_partial_4x4(
67+
im_mat1,
68+
im_mat2,
69+
pos,
70+
out_sizes[2],
71+
in_limits[0]);
6172

6273
for (int idx_c = 0; idx_c < FOUR; idx_c++) {
6374
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
@@ -71,14 +82,21 @@ void main() {
7182
self_sizes.y == 1);
7283

7384
// results is in transposed order w.r.t. the desired output
74-
imageStore(
85+
$if BATCH_MODE:
86+
imageStore(
7587
im_out,
7688
out_pos,
7789
vec4(
7890
beta * self_texel.x + alpha * results.data[idx_c][idx_r][0],
7991
beta * self_texel.x + alpha * results.data[idx_c][idx_r][1],
8092
beta * self_texel.x + alpha * results.data[idx_c][idx_r][2],
8193
beta * self_texel.x + alpha * results.data[idx_c][idx_r][3]));
94+
$else:
95+
imageStore(
96+
im_out,
97+
out_pos,
98+
vec4(
99+
beta * self_texel.x + alpha * results.data[idx_c][idx_r], 0.0, 0.0, 0.0));
82100
}
83101
}
84102
}

backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ addmm_optimized:
1010
NDIM: 3
1111
PACKING: C_packed
1212
MAT2_IS_TRANSPOSED: false
13+
BATCH_MODE: false
1314
generate_variant_forall:
1415
DTYPE:
1516
- VALUE: float
@@ -18,3 +19,8 @@ addmm_optimized:
1819
- NAME: addmm_optimized
1920
- NAME: linear_optimized
2021
MAT2_IS_TRANSPOSED: true
22+
- NAME: batch_addmm_optimized
23+
BATCH_MODE: true
24+
- NAME: batch_linear_optimized
25+
MAT2_IS_TRANSPOSED: true
26+
BATCH_MODE: true

backends/vulkan/runtime/graph/ops/glsl/matmul.h

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212

1313
// we avoid mat4 and vec4 usage here as they compile to much less efficient
1414
// SPIR-V
15-
struct FloatMatrix {
15+
// FOUR x FOUR block of float accumulators; this is the return type of
// matmul_partial_4x4 (the non-batched matmul path).
struct FloatMatrix_2d {
16+
float data[FOUR][FOUR];
17+
};
18+
19+
// FOUR x FOUR x FOUR block of float accumulators; this is the return type of
// matmul_partial_4x4x4 (the batched matmul path). Callers read the innermost
// index as the four components of each output texel.
struct FloatMatrix_3d {
1620
float data[FOUR][FOUR][FOUR];
1721
};
1822

@@ -146,13 +150,56 @@ vec4 get_texel_C_packed(
146150
return self_texel;
147151
}
148152

149-
FloatMatrix matmul_partial_4x4(
153+
// Computes a FOUR x FOUR tile of dot-product accumulators for the non-batched
// matmul path.
//
// For every texel column mat1_x in [0, K_texel_len), four texels are fetched
// from im_mat1 (rows FOUR * pos.y + offset) and four from im_mat2, and
// results.data[r][c] is incremented by dot(mat1 texel r, mat2 texel c).
//
// Parameters:
//   im_mat1, im_mat2 - input textures; only the z = 0 slice is ever fetched.
//   pos              - tile coordinate; pos.y selects the im_mat1 rows and
//                      pos.x selects the im_mat2 texels. pos.z is not read.
//   batch_size       - not referenced in this body.
//                      NOTE(review): presumably kept so the signature matches
//                      matmul_partial_4x4x4 - confirm.
//   K_texel_len      - number of texels iterated along the reduction axis.
//
// Returns: the accumulated tile. The first index of data corresponds to the
// im_mat1 row offset and the second to the im_mat2 offset.
FloatMatrix_2d matmul_partial_4x4(
154+
sampler3D im_mat1,
155+
sampler3D im_mat2,
156+
const ivec3 pos,
157+
const int batch_size,
158+
const int K_texel_len) {
159+
FloatMatrix_2d results;
160+
// zero every accumulator before the reduction loop
for (int i = 0; i < FOUR; i++) {
161+
for (int j = 0; j < FOUR; j++) {
162+
results.data[i][j] = 0.0f;
163+
}
164+
}
165+
// per-iteration caches: one texel per row/column offset of the tile
vec4 im_mat1_partial_load[FOUR];
166+
vec4 im_mat2_partial_load[FOUR];
167+
168+
for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
169+
for (int offset = 0; offset < FOUR; offset++) {
170+
// read and cache 4x4 tile of im_mat1 (one texel per offset, z fixed at 0)
171+
const int mat1_y = (FOUR * pos.y) + offset;
172+
const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, 0);
173+
im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
174+
// read and cache 4x4 tile of im_mat2
// transposed: mat1_x indexes im_mat2's x axis, offset walks its y axis;
// otherwise: offset walks im_mat2's x axis and mat1_x indexes its y axis
175+
#ifdef MAT2_IS_TRANSPOSED
176+
const int mat2_y = (FOUR * pos.x) + offset;
177+
const ivec3 mat2_pos = ivec3(mat1_x, mat2_y, 0);
178+
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
179+
#else
180+
const int mat2_x = (FOUR * pos.x) + offset;
181+
const ivec3 mat2_pos = ivec3(mat2_x, mat1_x, 0);
182+
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
183+
#endif
184+
}
185+
// perform partial dot products and add partial result to results
186+
for (int out_row = 0; out_row < FOUR; out_row++) {
187+
for (int out_col = 0; out_col < FOUR; out_col++) {
188+
results.data[out_row][out_col] +=
189+
dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
190+
}
191+
}
192+
}
193+
return results;
194+
}
195+
196+
FloatMatrix_3d matmul_partial_4x4x4(
150197
sampler3D im_mat1,
151198
sampler3D im_mat2,
152199
const ivec3 pos,
153200
const int batch_size,
154201
const int K_texel_len) {
155-
FloatMatrix results;
202+
FloatMatrix_3d results;
156203
for (int i = 0; i < FOUR; i++) {
157204
for (int j = 0; j < FOUR; j++) {
158205
for (int k = 0; k < FOUR; k++) {

backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
$if MAT2_IS_TRANSPOSED:
1414
#define MAT2_IS_TRANSPOSED
1515

16+
$if BATCH_MODE:
17+
#define BATCH_MODE
18+
1619
#include "indexing_utils.h"
1720
#include "matmul.h"
1821

@@ -41,27 +44,41 @@ void main() {
4144
return;
4245
}
4346

44-
FloatMatrix results = matmul_partial_4x4(
45-
im_mat1,
46-
im_mat2,
47-
pos,
48-
out_sizes[2],
49-
in_limits[0]);
47+
$if BATCH_MODE:
48+
FloatMatrix_3d results = matmul_partial_4x4x4(
49+
im_mat1,
50+
im_mat2,
51+
pos,
52+
out_sizes[2],
53+
in_limits[0]);
54+
$else:
55+
FloatMatrix_2d results = matmul_partial_4x4(
56+
im_mat1,
57+
im_mat2,
58+
pos,
59+
out_sizes[2],
60+
in_limits[0]);
5061

5162
for (int idx_c = 0; idx_c < FOUR; idx_c++) {
5263
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
5364
const ivec3 out_pos =
5465
ivec3(idx_r + FOUR * pos.x, idx_c + FOUR * pos.y, pos.z);
5566

5667
// results is in transposed order w.r.t. the desired output
57-
imageStore(
68+
$if BATCH_MODE:
69+
imageStore(
5870
im_out,
5971
out_pos,
6072
vec4(
6173
results.data[idx_c][idx_r][0],
6274
results.data[idx_c][idx_r][1],
6375
results.data[idx_c][idx_r][2],
6476
results.data[idx_c][idx_r][3]));
77+
$else:
78+
imageStore(
79+
im_out,
80+
out_pos,
81+
vec4(results.data[idx_c][idx_r], 0.0, 0.0, 0.0));
6582
}
6683
}
6784
}

backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ matmul_optimized:
1010
NDIM: 3
1111
PACKING: C_packed
1212
MAT2_IS_TRANSPOSED: false
13+
BATCH_MODE: false
1314
generate_variant_forall:
1415
DTYPE:
1516
- VALUE: float
@@ -18,3 +19,8 @@ matmul_optimized:
1819
- NAME: matmul_optimized
1920
- NAME: matmul_transposed_optimized
2021
MAT2_IS_TRANSPOSED: true
22+
- NAME: batch_matmul_optimized
23+
BATCH_MODE: true
24+
- NAME: batch_matmul_transposed_optimized
25+
MAT2_IS_TRANSPOSED: true
26+
BATCH_MODE: true

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ void add_addmm_optimized_node(
169169
std::string kernel_name = graph.get_bool(mat2_is_transposed)
170170
? "linear_optimized"
171171
: "addmm_optimized";
172+
173+
int mat1_dims = graph.sizes_of(mat1_W_packed).size();
174+
if (mat1_dims == 3) {
175+
kernel_name = "batch_" + kernel_name;
176+
}
177+
172178
add_dtype_suffix(kernel_name, graph.dtype_of(out));
173179

174180
graph.execute_nodes().emplace_back(new ExecuteNode(

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ void add_matmul_optimized_node(
134134
std::string kernel_name = mat2_is_transposed_val
135135
? "matmul_transposed_optimized"
136136
: "matmul_optimized";
137+
138+
int mat1_dims = graph.sizes_of(mat1_W_packed).size();
139+
if (mat1_dims == 3) {
140+
kernel_name = "batch_" + kernel_name;
141+
}
142+
137143
add_dtype_suffix(kernel_name, graph.dtype_of(out));
138144

139145
graph.execute_nodes().emplace_back(new ExecuteNode(

0 commit comments

Comments
 (0)