
Commit 7496923

Modifying permute op to support all tensor packing.

Differential Revision: D70587814
Pull Request resolved: #9215

1 parent ec3ea96 commit 7496923
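
For context: the Vulkan backend stores a tensor in a 3D image texture, grouping four consecutive elements along one dimension (the "packed" dimension: width, height, or channels) into a single RGBA texel. Previously the permute shader assumed channels-packing; this commit generalizes it to any packed dimension. Below is a minimal NumPy sketch of the packing idea, written for this note only (the helper and shapes are illustrative, not part of the commit):

import numpy as np

def pack_along(t, dim):
    # Group elements in runs of 4 along `dim`, zero-padding to a multiple of 4.
    # Loosely mirrors how a tensor is laid out into RGBA texels.
    pad = (-t.shape[dim]) % 4
    t = np.pad(t, [(0, pad if d == dim else 0) for d in range(t.ndim)])
    t = np.moveaxis(t, dim, -1)
    return t.reshape(*t.shape[:-1], -1, 4)  # trailing axis of 4 = one texel

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # NCHW
print(pack_along(x, 3).shape)  # width-packed    -> (2, 3, 4, 2, 4)
print(pack_along(x, 1).shape)  # channels-packed -> (2, 4, 5, 1, 4)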

File tree

4 files changed: +60 −54 lines

backends/vulkan/op_registry.py

Lines changed: 5 additions & 3 deletions
@@ -522,9 +522,6 @@ def register_view_op(features: OpFeatures):
 @update_features(
     [
         # Shape Manipulation
-        exir_ops.edge.aten.squeeze_copy.dims,
-        exir_ops.edge.aten.unsqueeze_copy.default,
-        exir_ops.edge.aten.permute_copy.default,
         exir_ops.edge.aten.t_copy.default,
         # Indexing and lookup
         exir_ops.edge.aten.flip.default,
@@ -556,10 +553,15 @@ def register_ported_op(features: OpFeatures):
     return features
 
 
+# Ops ported from PyTorch Vulkan backend. These ops are in a separate registry
+# because they support all packed dimensions.
 @update_features(
     [
         # Indexing and lookup
         exir_ops.edge.aten.slice_copy.Tensor,
+        # Shape Manipulation
+        exir_ops.edge.aten.squeeze_copy.dims,
+        exir_ops.edge.aten.unsqueeze_copy.default,
+        exir_ops.edge.aten.permute_copy.default,
     ]
 )
 def register_ported_op_all_packed_dims(features: OpFeatures):
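
These three shape ops move out of the channels-packed-only registry into register_ported_op_all_packed_dims (alongside slice_copy), so the partitioner now accepts them for width-, height-, and channels-packed tensors alike. For reference, their eager-mode equivalents, as a plain PyTorch illustration (not code from this commit; the tuple form of squeeze's dim argument assumes a recent PyTorch):

import torch

x = torch.randn(9, 2, 5, 7)
permuted = torch.permute(x, (0, 3, 1, 2))        # aten.permute_copy.default
unsqueezed = torch.unsqueeze(x, 0)               # aten.unsqueeze_copy.default
squeezed = torch.squeeze(torch.randn(1, 2, 1, 7), dim=(0, 2))  # aten.squeeze_copy.dims
print(permuted.shape)    # torch.Size([9, 7, 2, 5])
print(unsqueezed.shape)  # torch.Size([1, 9, 2, 5, 7])
print(squeezed.shape)    # torch.Size([2, 7])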

backends/vulkan/runtime/graph/ops/glsl/permute.glsl

Lines changed: 29 additions & 24 deletions
@@ -21,56 +21,61 @@ layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_i
 
 layout(push_constant) uniform PRECISION restrict Block {
   ivec4 out_limits;
-  ivec4 sizes;
+  ivec4 in_sizes;
   // output dims
   ivec4 out_ndims;
   // x = output channels aligned to 4, y = input channels aligned to 4
-  ivec2 ch_info;
+  ivec2 channel_info;
 };
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+layout(constant_id = 3) const int packed_dim = C_DIM;
 
 #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
 
 void main() {
-  const u16vec3 pos = u16vec3(gl_GlobalInvocationID);
+  u16vec3 pos = u16vec3(gl_GlobalInvocationID);
 
   if (any(greaterThanEqual(pos, out_limits.xyz))) {
     return;
   }
 
-  const int out_channel_4up = int(ch_info.x);
-  const int in_channel_4up = int(ch_info.y);
-  const int out_batch = int(sizes[3]);
   VEC4_T outval = VEC4_T(0.0);
-  ivec4 v = ivec4(0); // holds b,c,h,w
 
-  v[out_ndims[2]] = pos.y;
-  v[out_ndims[3]] = pos.x;
+  // scale up output position's packed dim
+  pos[packed_dim] <<= 2;
 
-  const int dst_index = pos.z << 2;
-  int dst_out_index = dst_index / out_channel_4up;
-  int dst_out_lane = dst_index % out_channel_4up;
+  // index of packed dim in bchw format
+  const int in_packed_dim_bchw_index = 3 - packed_dim;
 
-  for (int j = 0; j < 4; ++j, ++dst_out_lane) {
-    if (dst_out_index >= out_batch) {
-      // out of range
+  for (int j = 0; j < 4; ++j, pos[packed_dim]++) {
+    ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w
+    // determine input position based on output position and permute map
+    // out_ndims is in BCHW format
+    in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x);
+    in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x);
+    in_bchw_pos[out_ndims[2]] = pos.y;
+    in_bchw_pos[out_ndims[3]] = pos.x;
+
+    if (any(greaterThanEqual(in_bchw_pos.wzyx, in_sizes.xyzw))) {
       break;
     }
 
-    if (dst_out_lane == out_channel_4up) {
-      dst_out_lane = 0;
-      dst_out_index++;
-    }
+    // input tensor's packed dim pos (in xyz format) corresponding to output
+    // tensor's pos (which is also in xyz format)
+    const int in_packed_dim_pos = in_bchw_pos[in_packed_dim_bchw_index];
 
-    v[out_ndims[0]] = dst_out_index;
-    v[out_ndims[1]] = dst_out_lane;
+    // calculate input position in y axis using batch and channel index,
+    // which are in_bchw_pos.x and in_bchw_pos.y respectively
+    in_bchw_pos.y = in_bchw_pos.y + in_bchw_pos.x * channel_info.y;
 
-    int src_index = v[0] * in_channel_4up + v[1];
+    // scale down input tensor's packed dim pos to perform fetch
+    in_bchw_pos[in_packed_dim_bchw_index] >>= 2;
 
-    VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(v[3], v[2], src_index >> 2), 0));
-    outval[j] = inval[src_index & 0x3];
+    // fetch input texel
+    VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(in_bchw_pos.wzy), 0));
+    outval[j] = inval[in_packed_dim_pos & 0x3];
   }
 
+  pos[packed_dim] = uint16_t(gl_GlobalInvocationID[packed_dim]);
+
   imageStore(image_out, pos, outval);
 }
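
The rewritten kernel walks the packed dimension of the output texel lane by lane and, for each lane, maps the output BCHW position through out_ndims back to an input BCHW position. A scalar Python model of that index math follows; names mirror the shader and in_sizes is taken in WHCN order, but this is an illustration written for this note, not generated or committed code:

def permute_texel(pos, packed_dim, out_ndims, channel_info, in_sizes):
    # pos: output texel position (x, y, z); packed_dim: 0=W, 1=H, 2=C.
    # channel_info: (out channels, in channels), aligned up to 4 only when
    # channels is the packed dim; in_sizes: input sizes in WHCN order.
    elem = list(pos)
    elem[packed_dim] <<= 2             # texel units -> element units
    in_packed_bchw = 3 - packed_dim    # packed dim's index in BCHW order
    fetches = []
    for _ in range(4):
        b_c_h_w = [0, 0, 0, 0]
        b_c_h_w[out_ndims[0]] = elem[2] // channel_info[0]  # batch from z
        b_c_h_w[out_ndims[1]] = elem[2] % channel_info[0]   # channel from z
        b_c_h_w[out_ndims[2]] = elem[1]                     # height from y
        b_c_h_w[out_ndims[3]] = elem[0]                     # width from x
        if any(p >= s for p, s in zip(b_c_h_w, reversed(in_sizes))):
            break                      # lane falls outside the input tensor
        lane = b_c_h_w[in_packed_bchw] & 0x3   # lane within the input texel
        b_c_h_w[1] += b_c_h_w[0] * channel_info[1]  # fold batch into z axis
        b_c_h_w[in_packed_bchw] >>= 2  # element units -> texel units
        fetches.append(((b_c_h_w[3], b_c_h_w[2], b_c_h_w[1]), lane))
        elem[packed_dim] += 1
    return fetches  # [(input texel xyz, lane), ...] for the 4 output lanes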

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 11 additions & 8 deletions
@@ -28,8 +28,7 @@ void check_args(
     const api::vTensor& in,
     const std::vector<int64_t>& permute_dims,
     const api::vTensor& out) {
-  VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
-  VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
+  VK_CHECK_COND(check_same_packed_dim(in, out));
 
   // This implementation doesn't require the input tensor to have the same
   // dim size as the argument. The code will work as long as the input tensor's
@@ -72,10 +71,14 @@ void add_permute_node(
   int32_t out_channels = dim_at<kChannel4D>(t_out->sizes());
   int32_t in_channels = dim_at<kChannel4D>(t_in->sizes());
 
-  int32_t out_c_aligned = utils::align_up_4(out_channels);
-  int32_t in_c_aligned = utils::align_up_4(in_channels);
+  const auto packed_dim = graph.packed_dim_of(in);
+  ivec2 channel_info = {out_channels, in_channels};
+  if (packed_dim == WHCN::kChannelsDim) {
+    channel_info[0] = utils::align_up_4(channel_info[0]);
+    channel_info[1] = utils::align_up_4(channel_info[1]);
+  }
 
-  const ivec2 ch_info = {out_c_aligned, in_c_aligned};
+  const vkapi::SpecVarList spec_vars = {packed_dim};
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
@@ -86,14 +89,14 @@ void add_permute_node(
        {in, vkapi::MemoryAccessType::READ}},
       {},
       // Specialization Constants
-      {},
+      spec_vars,
       // Resizing Logic
      nullptr,
       {},
       {{graph.logical_limits_pc_of(out),
-        graph.sizes_pc_of(out),
+        graph.sizes_pc_of(in),
         PushConstantDataInfo(&out_dims, sizeof(out_dims)),
-        PushConstantDataInfo(&ch_info, sizeof(ch_info))}}));
+        PushConstantDataInfo(&channel_info, sizeof(channel_info))}}));
 }
 
 void add_permute_node(
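
Host-side, the channel counts in channel_info are now aligned to a multiple of 4 only for channels-packed tensors, where the texture's z axis folds batch and texel-aligned channel groups together; width- and height-packed tensors keep their exact counts. A small worked example (illustrative values, not from the commit):

def align_up_4(n):
    # mirrors utils::align_up_4
    return (n + 3) & ~3

out_channels, in_channels = 9, 9
channel_info = [out_channels, in_channels]
packed_dim_is_channels = True
if packed_dim_is_channels:
    channel_info = [align_up_4(c) for c in channel_info]
print(channel_info)  # [12, 12]

# With channels packed, an element-space z coordinate splits as:
z_elem = 30
print(z_elem // channel_info[0], z_elem % channel_info[0])  # batch 2, channel 6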

backends/vulkan/test/op_tests/cases.py

Lines changed: 15 additions & 19 deletions
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import itertools
+
 from collections import namedtuple
 from typing import Callable
 
@@ -457,26 +459,20 @@ def get_select_int_inputs():
 
 @register_test_suite(["aten.permute.default", "aten.permute_copy.default"])
 def get_permute_inputs():
-    test_suite = VkTestSuite(
-        [
-            ((9, 2, 9, 4), [0, 1, 2, 3]),
-            ((9, 2, 9, 4), [0, 1, 3, 2]),
-            ((9, 2, 9, 4), [0, 2, 1, 3]),
-            ((9, 2, 9, 4), [0, 2, 3, 1]),
-            ((9, 2, 9, 4), [0, 3, 1, 2]),
-            ((9, 2, 9, 4), [0, 3, 2, 1]),
-            ((9, 2, 9, 4), [3, 0, 1, 2]),
-            ((9, 2, 9, 4), [3, 2, 0, 1]),
-            ((9, 2, 9, 4), [2, 3, 0, 1]),
-            ((9, 2, 9, 4), [2, 0, 3, 1]),
-            ((9, 2, 9), [2, 0, 1]),
-            ((9, 2, 9), [1, 2, 0]),
-            ((9, 2), [0, 1]),
-            ((9, 2), [1, 0]),
-        ]
-    )
+    batch_tests = [
+        ((9, 2, 5, 7), out_axis) for out_axis in itertools.permutations([0, 1, 2, 3])
+    ]
+    channel_tests = [
+        ((9, 2, 5), out_axis) for out_axis in itertools.permutations([0, 1, 2])
+    ]
+    wh_tests = [((9, 2), out_axis) for out_axis in itertools.permutations([0, 1])]
+    test_suite = VkTestSuite(batch_tests + channel_tests + wh_tests)
 
-    test_suite.layouts = ["utils::kChannelsPacked"]
+    test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kHeightPacked",
+        "utils::kChannelsPacked",
+    ]
     return test_suite
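
The test suite now enumerates every axis permutation instead of a hand-picked subset, and each case runs under all three packed layouts. A quick count of what that generates (plain Python, for illustration):

import itertools

cases = (
    len(list(itertools.permutations(range(4))))    # 24 four-dim permutations
    + len(list(itertools.permutations(range(3))))  # 6 three-dim permutations
    + len(list(itertools.permutations(range(2))))  # 2 two-dim permutations
)
print(cases, cases * 3)  # 32 cases, 96 including the 3 layouts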
