@@ -25,10 +25,12 @@ using utils::uvec4;
25
25
namespace {
26
26
27
27
void check_args (
28
- const api::vTensor& in,
29
- const std::vector<int64_t >& permute_dims,
30
- const api::vTensor& out) {
31
- VK_CHECK_COND (check_same_packed_dim (in, out));
28
+ ComputeGraph& graph,
29
+ const ValueRef in,
30
+ const ValueRef permute_dims,
31
+ const ValueRef out) {
32
+ (void )permute_dims;
33
+ VK_CHECK_COND (check_same_packed_dim (graph, in, out));
32
34
33
35
// This implementation doesn't not requires the input tensor to have the same
34
36
// dim size as the argument. The code will work as long as the input tensor's
@@ -38,40 +40,94 @@ void check_args(
38
40
39
41
} // namespace
40
42
43
+ void resize_permute_node (
44
+ ComputeGraph* graph,
45
+ const std::vector<ArgGroup>& args,
46
+ const std::vector<ValueRef>& resize_args) {
47
+ const ValueRef out = args[0 ].refs [0 ];
48
+ const ValueRef in = args[1 ].refs [0 ];
49
+
50
+ const std::vector<int64_t > in_sizes = graph->sizes_of (in);
51
+ const std::vector<int64_t > out_sizes = graph->sizes_of (out);
52
+
53
+ const std::vector<int64_t > permute_dims =
54
+ graph->extract_int_or_symint_list (resize_args[0 ]);
55
+
56
+ if (in_sizes.size () == out_sizes.size () &&
57
+ in_sizes.size () == permute_dims.size ()) {
58
+ std::vector<int64_t > new_out_sizes (out_sizes.size (), 1 );
59
+ const int64_t out_ndim = std::max (in_sizes.size (), out_sizes.size ());
60
+ for (int i = 0 ; i < out_ndim; i++) {
61
+ const int64_t permute_dim = permute_dims.at (i);
62
+ new_out_sizes.at (i) = in_sizes.at (permute_dim);
63
+ }
64
+ graph->virtual_resize (out, new_out_sizes);
65
+ }
66
+ // Case where permute is being used to implement squeeze
67
+ else if (
68
+ in_sizes.size () > out_sizes.size () &&
69
+ in_sizes.size () == permute_dims.size ()) {
70
+ std::vector<int64_t > new_out_sizes (out_sizes.size (), 1 );
71
+ const size_t offset = in_sizes.size () - out_sizes.size ();
72
+ for (int i = 0 ; i < out_sizes.size (); i++) {
73
+ const int64_t permute_dim = permute_dims.at (i + offset);
74
+ new_out_sizes.at (i) = in_sizes.at (permute_dim);
75
+ }
76
+ graph->virtual_resize (out, new_out_sizes);
77
+ }
78
+ // Case where Permute is being used to implement unsqueeze
79
+ else if (
80
+ in_sizes.size () < out_sizes.size () &&
81
+ out_sizes.size () == permute_dims.size ()) {
82
+ std::vector<int64_t > new_out_sizes (out_sizes.size (), 1 );
83
+ const size_t offset = out_sizes.size () - in_sizes.size ();
84
+ for (int i = 0 ; i < out_sizes.size (); i++) {
85
+ int64_t permute_dim = permute_dims.at (i) - offset;
86
+ if (permute_dim >= 0 ) {
87
+ new_out_sizes.at (i) = in_sizes.at (permute_dim);
88
+ }
89
+ }
90
+ graph->virtual_resize (out, new_out_sizes);
91
+ } else {
92
+ VK_THROW (" Invalid permute dims" );
93
+ }
94
+ }
95
+
41
96
void add_permute_node (
42
97
ComputeGraph& graph,
43
- ValueRef in,
44
- const std::vector<int64_t >& permute_dims,
45
- ValueRef out) {
46
- vTensorPtr t_in = graph.get_tensor (in);
47
- vTensorPtr t_out = graph.get_tensor (out);
48
-
49
- check_args (*t_in, permute_dims, *t_out);
98
+ const ValueRef in,
99
+ const ValueRef permute_dims,
100
+ const ValueRef out) {
101
+ check_args (graph, in, permute_dims, out);
50
102
51
103
ivec4 out_dims{0 , 1 , 2 , 3 };
52
104
53
105
// Special cases of squeeze/unsqueeze. Because the input dim size can be
54
- // different with output dim size. So pick t_in->dim( ) if squeeze, and
55
- // t_out->dim( ) if unsqueeze to create parameter for permute.
56
- int64_t out_ndim = std::max (t_in-> dim ( ), t_out-> dim ( ));
106
+ // different with output dim size. So pick graph.dim_of(in ) if squeeze, and
107
+ // graph.dim_of(out ) if unsqueeze to create parameter for permute.
108
+ const int64_t out_ndim = std::max (graph. dim_of (in ), graph. dim_of (out ));
57
109
std::vector<bool > seen (out_ndim);
58
- for (int i = 0 ; i < out_ndim; i++) {
59
- int64_t permute_dim = permute_dims[i];
60
- VK_CHECK_COND (
61
- !seen[permute_dim], " Argument dim " , permute_dim, " is repeated" );
62
- seen[permute_dim] = true ;
63
-
64
- out_dims[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
110
+ {
111
+ IntListPtr permute_dims_ptr = graph.get_int_list (permute_dims);
112
+ for (int i = 0 ; i < out_ndim; i++) {
113
+ int64_t permute_dim = permute_dims_ptr->at (i);
114
+ VK_CHECK_COND (
115
+ !seen[permute_dim], " Argument dim " , permute_dim, " is repeated" );
116
+ seen[permute_dim] = true ;
117
+
118
+ out_dims[(4u - out_ndim) + i] =
119
+ utils::safe_downcast<int32_t >(permute_dim + (4 - out_ndim));
120
+ }
65
121
}
66
122
67
123
std::string kernel_name = " permute" ;
68
124
kernel_name.reserve (kShaderNameReserve );
69
- add_dtype_suffix (kernel_name, *t_out );
125
+ add_dtype_suffix (kernel_name, graph. dtype_of (out) );
70
126
71
- int32_t out_channels = dim_at<kChannel4D >(t_out-> sizes ( ));
72
- int32_t in_channels = dim_at<kChannel4D >(t_in-> sizes ( ));
127
+ const int32_t out_channels = dim_at<kChannel4D >(graph. sizes_of (out ));
128
+ const int32_t in_channels = dim_at<kChannel4D >(graph. sizes_of (in ));
73
129
74
- const auto packed_dim = graph.packed_dim_of (in);
130
+ const int32_t packed_dim = graph.packed_dim_of (in);
75
131
ivec2 channel_info = {out_channels, in_channels};
76
132
if (packed_dim == WHCN::kChannelsDim ) {
77
133
channel_info[0 ] = utils::align_up_4 (channel_info[0 ]);
@@ -95,19 +151,9 @@ void add_permute_node(
95
151
// Specialization Constants
96
152
spec_vars,
97
153
// Resize Args
98
- {},
154
+ {permute_dims },
99
155
// Resizing Logic
100
- nullptr ));
101
- }
102
-
103
- void add_permute_node (
104
- ComputeGraph& graph,
105
- ValueRef in,
106
- ValueRef permute_dims_ref,
107
- ValueRef out) {
108
- IntListPtr permute_dims = graph.get_int_list (permute_dims_ref);
109
-
110
- add_permute_node (graph, in, *permute_dims, out);
156
+ resize_permute_node));
111
157
}
112
158
113
159
void permute (ComputeGraph& graph, const std::vector<ValueRef>& args) {
0 commit comments