Skip to content

Commit 3e2cfc7

Browse files
nathanaelseefacebook-github-bot
authored andcommitted
Integrate axis mapping into binary op (#5408)
Summary: Pull Request resolved: #5408 Update binary op to support axis mapped textures. Reviewed By: SS-JIA Differential Revision: D62622013 fbshipit-source-id: 070ce40e22f4fca4d438d8dd3a33887bee8fc78a
1 parent 618466e commit 3e2cfc7

File tree

5 files changed

+61
-25
lines changed

5 files changed

+61
-25
lines changed

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,38 +19,45 @@
1919

2020
layout(std430) buffer;
2121

22-
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
23-
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
24-
${layout_declare_tensor(2, "r", "t_other", DTYPE, STORAGE)}
25-
${layout_declare_ubo(3, "ivec4", "out_sizes")}
26-
${layout_declare_ubo(4, "ivec4", "in_sizes")}
27-
${layout_declare_ubo(5, "ivec4", "other_sizes")}
28-
${layout_declare_ubo(6, "ivec2", "broadcast_params")}
29-
${layout_declare_ubo(7, "float", "alpha")}
22+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
23+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
24+
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
25+
${layout_declare_ubo(B, "ivec4", "out_sizes")}
26+
${layout_declare_ubo(B, "ivec4", "out_axis_map")}
27+
${layout_declare_ubo(B, "ivec4", "in_sizes")}
28+
${layout_declare_ubo(B, "ivec4", "in_axis_map")}
29+
${layout_declare_ubo(B, "ivec4", "other_sizes")}
30+
${layout_declare_ubo(B, "ivec4", "other_axis_map")}
31+
${layout_declare_ubo(B, "ivec2", "broadcast_params")}
32+
${layout_declare_ubo(B, "float", "alpha")}
3033

3134
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3235

3336
layout(constant_id = 3) const int packed_dim = C_DIM;
3437

3538
void main() {
39+
// pos is physical (x, y, z), as global workgroup uses image extents
3640
const ivec3 pos = ivec3(gl_GlobalInvocationID);
37-
const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim);
41+
// physical pos (x, y, z) -> logical (w, c, h, n) output
42+
const ivec4 idx = to_tensor_idx(pos, out_sizes, out_axis_map, packed_dim);
3843

3944
if (any(greaterThanEqual(idx, out_sizes))) {
4045
return;
4146
}
4247

48+
// broadcast on logical sizes
4349
ivec4 in_idx = broadcast_indices(idx, in_sizes);
44-
VEC4_T in_texel = VEC4_T(texelFetch(
50+
VEC4_T in_texel = VEC4_T(load_texel(
4551
t_in,
46-
to_texture_pos(in_idx, in_sizes, packed_dim),
47-
0));
52+
// read axis mapped texel
53+
to_texture_pos(in_idx, in_sizes, in_axis_map, packed_dim)));
4854

55+
// broadcast on logical sizes
4956
ivec4 other_idx = broadcast_indices(idx, other_sizes);
50-
VEC4_T other_texel = VEC4_T(texelFetch(
57+
VEC4_T other_texel = VEC4_T(load_texel(
5158
t_other,
52-
to_texture_pos(other_idx, other_sizes, packed_dim),
53-
0));
59+
// read axis mapped texel
60+
to_texture_pos(other_idx, other_sizes, other_axis_map, packed_dim)));
5461

5562
// Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
5663
if (broadcast_params.x > 0) {
@@ -60,5 +67,7 @@ void main() {
6067
other_texel = other_texel.xxxx;
6168
}
6269

63-
imageStore(t_out, pos, VEC4_T(op(in_texel, other_texel, alpha)));
70+
imageStore(t_out,
71+
to_texture_pos(idx, out_sizes, out_axis_map, packed_dim),
72+
VEC4_T(op(in_texel, other_texel, alpha)));
6473
}

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,11 @@ void add_binary_op_node(
8585
{{arg1, arg2}, vkapi::MemoryAccessType::READ}},
8686
// Shader params buffers
8787
{t_out->sizes_ubo(),
88+
t_out->axis_map_ubo(),
8889
t_in1->sizes_ubo(),
90+
t_in1->axis_map_ubo(),
8991
t_in2->sizes_ubo(),
92+
t_in2->axis_map_ubo(),
9093
graph.create_params_buffer(broadcast_params),
9194
graph.create_params_buffer(alpha_val)},
9295
// Specialization Constants

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def get_binary_elementwise_inputs():
4949
((S, S1, S2), (S, S1, S2)),
5050
((S, S1, S2), (S, S1, 1), 2.0),
5151
((S, S1, S2), (S, 1, S2), 2.0),
52+
((XS, S, S1, S2), (XS, S, 1, 1), 2.0),
5253
]
5354
)
5455
test_suite.layouts = [

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,16 @@ def forward(self, x, y, w):
204204

205205
self.lower_module_and_test_output(add_module, sample_inputs)
206206

207+
sample_inputs = (
208+
torch.rand(size=(4, 5, 2, 3), dtype=torch.float32),
209+
torch.rand(size=(4, 5, 2, 3), dtype=torch.float32),
210+
torch.rand(
211+
size=(2, 3), dtype=torch.float32
212+
), # test broadcasting on packed dim
213+
)
214+
215+
self.lower_module_and_test_output(add_module, sample_inputs)
216+
207217
def test_vulkan_backend_add_int(self):
208218
class AddIntModule(torch.nn.Module):
209219
def __init__(self):

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,6 +1399,7 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
13991399
TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14001400
GraphConfig config;
14011401
ComputeGraph graph(config);
1402+
size_t expected_vma_allocation_count = 0;
14021403

14031404
std::vector<int64_t> size_big = {12, 64, 64};
14041405
std::vector<int64_t> size_small = {12, 64, 64};
@@ -1417,7 +1418,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14171418
// +2: t.sizes_ubo() for each staging shader
14181419
// +2: t.axis_map_ubo() for each staging shader
14191420
// +2: staging buffer for each input tensor
1420-
EXPECT_TRUE(get_vma_allocation_count() == 6);
1421+
expected_vma_allocation_count += 6;
1422+
EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count);
14211423

14221424
ValueRef c = graph.add_tensor(
14231425
size_big,
@@ -1427,16 +1429,22 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14271429
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
14281430
addFn(graph, {a.value, b.value, kDummyValueRef, c});
14291431

1432+
// +2: alpha UBO, broadcast UBO for arithmetic shader
1433+
// +1: t.sizes_ubo() for arithmetic shader output c
1434+
// +1: t.axis_map_ubo() for arithmetic shader output c
1435+
expected_vma_allocation_count += 4;
1436+
EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count);
1437+
14301438
IOValueRef d = graph.add_input_tensor(
14311439
size_small,
14321440
vkapi::kFloat,
14331441
/*shared_object_idx = */ 2);
14341442

1435-
// +2: alpha UBO, broadcast UBO for arithmetic shader
14361443
// +1: t.sizes_ubo() uniform buffer for staging shader
14371444
// +1: t.axis_map_ubo() uniform buffer for staging shader
14381445
// +1: staging buffer for the input tensor
1439-
EXPECT_TRUE(get_vma_allocation_count() == 12);
1446+
expected_vma_allocation_count += 3;
1447+
EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count);
14401448

14411449
ValueRef e = graph.add_tensor(
14421450
size_big,
@@ -1446,21 +1454,26 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14461454
auto mulFn = VK_GET_OP_FN("aten.mul.Tensor");
14471455
mulFn(graph, {c, d.value, e});
14481456

1457+
// +2: alpha UBO, broadcast UBO for arithmetic shader
1458+
// +1: t.sizes_ubo() for arithmetic shader output e
1459+
// +1: t.axis_map_ubo() for arithmetic shader output e
1460+
expected_vma_allocation_count += 4;
1461+
EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count);
1462+
14491463
IOValueRef out = {};
14501464
out.value = e;
14511465
out.staging = graph.set_output_tensor(out.value);
14521466

1453-
// +2: alpha UBO, broadcast UBO for arithmetic shader
1454-
// +1: t.sizes_ubo() for staging shader
1455-
// +1: t.axis_map_ubo() for staging shader
1456-
// +1 staging buffer for the input tensor
1457-
EXPECT_TRUE(get_vma_allocation_count() == 17);
1467+
// +1: staging buffer for the output tensor
1468+
expected_vma_allocation_count += 1;
1469+
EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count);
14581470

14591471
graph.prepare();
14601472
graph.encode_execute();
14611473

14621474
// +3: shared memory allocations for tensors
1463-
EXPECT_TRUE(get_vma_allocation_count() == 20);
1475+
expected_vma_allocation_count += 3;
1476+
EXPECT_EQ(get_vma_allocation_count(), expected_vma_allocation_count);
14641477

14651478
// Run graph
14661479

0 commit comments

Comments
 (0)