Commit c5ff517

Update on "[ET-VK] Allow clone op to transfer between memory layouts and storage types"
## Changes

As title. Extend the functionality of the `aten.clone` operator to allow transitioning the storage type and memory layout between the input and output tensors.

## Context

This functionality will be used to transition input tensors to the optimal storage type and memory layout before entering the execution of an op. The transition nodes will be added by a memory metadata tagging pass that will be introduced in a subsequent diff.

Differential Revision: [D65277710](https://our.internmc.facebook.com/intern/diff/D65277710/)

[ghstack-poisoned]
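For illustration only, a minimal sketch of how a storage-type transition could be dispatched inside the Vulkan graph builder. The helper name, the argument order of `add_clone_node`, and `graph.storage_type_of` are assumptions made for this sketch; only `add_clone_node`, `add_image_to_buffer_node`, and `add_buffer_to_image_node` appear in this diff (see Clone.cpp below).

```cpp
// Hypothetical dispatch helper (not part of this diff): route a clone between
// two tensors to a kernel based on their storage types.
void add_clone_transfer_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef out) {
  // graph.storage_type_of(...) is assumed here for illustration.
  const bool in_is_buffer = graph.storage_type_of(in) == utils::kBuffer;
  const bool out_is_buffer = graph.storage_type_of(out) == utils::kBuffer;

  if (!in_is_buffer && out_is_buffer) {
    // Texture -> buffer transfer; uses the renamed clone_image_to_buffer shader.
    add_image_to_buffer_node(graph, in, out);
  } else if (in_is_buffer && !out_is_buffer) {
    // Buffer -> texture transfer; uses the renamed clone_buffer_to_image shader.
    add_buffer_to_image_node(graph, in, out);
  } else {
    // Same storage type; a regular clone can still change the memory layout.
    add_clone_node(graph, in, out);  // argument order (in, out) assumed
  }
}
```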
2 parents 7513dfa + d2cd73d commit c5ff517

Showing 11 changed files with 46 additions and 263 deletions.


backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml

Lines changed: 1 addition & 1 deletion
@@ -19,5 +19,5 @@ image_to_nchw:
   - NAME: image_to_nchw_texture3d
   - NAME: image_to_nchw_texture2d
     STORAGE: texture2d
-  - NAME: image_to_buffer
+  - NAME: clone_image_to_buffer
     TO_STAGING: False

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml

Lines changed: 1 addition & 1 deletion
@@ -19,5 +19,5 @@ nchw_to_image:
   - NAME: nchw_to_image_texture3d
   - NAME: nchw_to_image_texture2d
     STORAGE: texture2d
-  - NAME: buffer_to_image
+  - NAME: clone_buffer_to_image
     FROM_STAGING: False

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 7 additions & 3 deletions
@@ -25,7 +25,11 @@ void resize_clone_node(
   (void)extra_args;
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
   vTensorPtr in = graph->get_tensor(args[1].refs[0]);
-  out->virtual_resize(in->sizes());
+  // TODO: support for when dimensionality doesn't match, i.e. clone is used to
+  // implement squeeze.
+  if (out->dim() == in->dim()) {
+    out->virtual_resize(in->sizes());
+  }
 }
 
 void add_clone_node(
@@ -56,7 +60,7 @@ void add_image_to_buffer_node(
     ComputeGraph& graph,
     const ValueRef image,
     const ValueRef buffer) {
-  std::string kernel_name = "image_to_buffer";
+  std::string kernel_name = "clone_image_to_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(image));
   vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
 
@@ -80,7 +84,7 @@ void add_buffer_to_image_node(
     ComputeGraph& graph,
     const ValueRef buffer,
     const ValueRef image) {
-  std::string kernel_name = "buffer_to_image";
+  std::string kernel_name = "clone_buffer_to_image";
   add_dtype_suffix(kernel_name, graph.dtype_of(image));
   vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 1 addition & 1 deletion
@@ -56,8 +56,8 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader(
   if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
       !int8_buffer_enabled) {
     kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
-    add_dtype_suffix(kernel_name, v_src);
     add_storage_type_suffix(kernel_name, v_src);
+    add_dtype_suffix(kernel_name, v_src);
     return VK_KERNEL_FROM_STR(kernel_name);
   }
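The only change here (and in test_utils.cpp below) is the order in which the shader-name suffixes are appended: storage type first, then dtype. A sketch of the effect, assuming the helpers append suffixes such as `_texture3d` and `_int8` (the exact suffix strings are assumptions for illustration):

```cpp
std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
add_storage_type_suffix(kernel_name, v_src);  // e.g. "bitw8_image_to_nchw_nobitw8buffer_texture3d"
add_dtype_suffix(kernel_name, v_src);         // e.g. "bitw8_image_to_nchw_nobitw8buffer_texture3d_int8"
// Before this change the helpers ran in the opposite order, producing
// e.g. "bitw8_image_to_nchw_nobitw8buffer_int8_texture3d".
```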

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 1 addition & 1 deletion
@@ -118,8 +118,8 @@ void record_bitw8_image_to_nchw_nobitw8buffer_op(
   utils::uvec3 global_wg_size = {buffer_len, 1, 1};
 
   std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
-  add_dtype_suffix(kernel_name, v_src);
   add_storage_type_suffix(kernel_name, v_src);
+  add_dtype_suffix(kernel_name, v_src);
 
   context->submit_compute_job(
       VK_KERNEL_FROM_STR(kernel_name),

extension/llm/custom_ops/targets.bzl

Lines changed: 1 addition & 6 deletions
@@ -1,9 +1,4 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
-load(
-    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
-    "get_compiler_optimization_flags",
-)
-
 
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.
@@ -39,7 +34,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:reduce_util",
             "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
         ],
-        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(),
+        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
         visibility = [
            "//executorch/...",
            "//executorch/extension/llm/custom_ops/...",

kernels/optimized/cpu/binary_ops.h

Lines changed: 2 additions & 86 deletions
@@ -41,62 +41,10 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
-  kBroadcastNdByNd,
-  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-
-// Find the single broadcast dimension if it exists.
-// This path aims to handle broadcast of the following form
-// A = [a1, a2,., 1, .., an]
-// B = [b1, b2,., bm, .., bn]
-// OR
-// A = [a1, a2,., am, .., an]
-// B = [b1, b2,., 1, .., bn]
-int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
-  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
-  auto lhs_end = lhs.sizes().end();
-
-  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
-  auto rhs_end = rhs.sizes().end();
-
-  const auto lhs_size = lhs_end - lhs_begin;
-  const auto rhs_size = rhs_end - rhs_begin;
-
-  // Following example is not handled at the moment
-  // [1, 3, 4, 5]
-  // [2, 3, 4, 5]
-  if (lhs_size != rhs_size) {
-    return 0;
-  }
-
-  int32_t broadcast_dim = 0;
-  // Check
-  // 1. if any dim value is 1 (it constitutes a broadcast dim)
-  // 2. If more than one dim value is 1 (we cannot handle)
-  // 3. If non-1 dim values are equal
-  lhs_end--;
-  rhs_end--;
-  while (lhs_end != lhs_begin) {
-    if (*lhs_end == 1 || *rhs_end == 1) {
-      // If more than one broadcast dim is found, return 0.
-      if (broadcast_dim != 0) {
-        return 0;
-      }
-      // negative index is used
-      broadcast_dim = lhs_end - lhs.sizes().end();
-    } else if (*lhs_end != *rhs_end) {
-      // If non-1 dim values are not equal, return 0.
-      return 0;
-    }
-    lhs_end--;
-    rhs_end--;
-  }
-  return broadcast_dim;
-}
-
-inline ElementwiseOptimizedPath select_broadcast_optimized_path(
+inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -115,17 +63,6 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
-  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
-  // Right now we dont handle last dim broadcast
-  if (broadcast_dim < -1) {
-    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
-          return x == 1;
-        }) == 1) {
-      return ElementwiseOptimizedPath::kBroadcastNdByNd;
-    } else {
-      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
-    }
-  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -148,28 +85,7 @@ ElementwiseOptimizedPath inline select_optimized_path(
       internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_optimized_path(a, b);
-}
-
-std::array<int32_t, 3> inline get_normalized_tensor_size(
-    const Tensor& a,
-    const int32_t broadcast_dim) {
-  ET_CHECK_MSG(
-      a.dim() > broadcast_dim,
-      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
-      a.dim(),
-      broadcast_dim);
-  std::array<int32_t, 3> normalized_tensor_size;
-  normalized_tensor_size[0] = 1;
-  normalized_tensor_size[1] = a.size(broadcast_dim);
-  normalized_tensor_size[2] = 1;
-  for (size_t i = 0; i < broadcast_dim; i++) {
-    normalized_tensor_size[0] *= a.size(i);
-  }
-  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
-    normalized_tensor_size[2] *= a.size(i);
-  }
-  return normalized_tensor_size;
+  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
 }
 
 } // namespace executor

kernels/optimized/cpu/op_mul.cpp

Lines changed: 7 additions & 30 deletions
@@ -130,19 +130,15 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    if (selected_optimized_path ==
+        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
       lhs = &b;
       rhs = &a;
     } else {
       // Catch failure to update logic when adding new broadcasting possibility.
       ET_DCHECK(
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcastNdByNd));
+          selected_optimized_path ==
+          ElementwiseOptimizedPath::kBroadcast2dBy1d);
       lhs = &a;
       rhs = &b;
     }
@@ -153,34 +149,15 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
-    int64_t outer_size = 1;
-    int64_t broadcast_size;
-    int64_t inner_size;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-      auto normalized_tensor_size_lhs =
-          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-      outer_size = normalized_tensor_size_lhs[0];
-      broadcast_size = normalized_tensor_size_lhs[1];
-      inner_size = normalized_tensor_size_lhs[2];
-    } else {
-      broadcast_size = lhs->sizes()[lhs->dim() - 2];
-      inner_size = lhs->sizes()[lhs->dim() - 1];
-    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
+      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
           [](Vec x, Vec y) { return x * y; },
          out.mutable_data_ptr<CTYPE>(),
          lhs->const_data_ptr<CTYPE>(),
          rhs->const_data_ptr<CTYPE>(),
-          outer_size,
-          broadcast_size,
-          inner_size);
+          lhs->sizes()[lhs->dim() - 2],
+          lhs->sizes()[lhs->dim() - 1]);
     });
   } else {
     ScalarType common_type =

kernels/optimized/vec/functional_base.h

Lines changed: 24 additions & 44 deletions
@@ -326,49 +326,10 @@ inline void map4(
 }
 
 
-// This function implements broadcasting binary operation on two tensors
-// where lhs tensor is treated to be of shape [outer_size, broadcast_size, inner_size]
-// and rhs tensor is treated to be of shape [outer_size, 1, inner_size]
-// And this 1st dimension is considered broadcasting dimension
-// This formula can map broadcasting on any dim=broadcast_dim
-// for any two N dimensional tensors, where 0 < braodcast_dim < N-1
-template <typename scalar_t, typename Op>
-inline void broadcasting_map_3d_and_unsqueezed_3d(
-    const Op& vec_fun,
-    scalar_t* output_data,
-    const scalar_t* lhs,
-    const scalar_t* rhs,
-    int64_t outer_size,
-    int64_t broadcast_size,
-    int64_t inner_size) {
-  using Vec = vec::Vectorized<scalar_t>;
-  int64_t outer_stride_lhs = inner_size * broadcast_size;
-  int64_t outer_stride_rhs = inner_size;
-  int64_t broadcast_stride_lhs = inner_size;
-  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
-    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
-    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
-    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
-    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
-      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
-      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
-      int64_t inner_idx = 0;
-      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
-        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
-        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
-        Vec output_vec = vec_fun(data_vec, data_vec2);
-        output_vec.store(output_data_row_2 + inner_idx);
-      }
-      if (inner_size - inner_idx > 0) {
-        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
-        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
-        Vec output_vec = vec_fun(data_vec, data_vec2);
-        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
-      }
-    }
-  }
-}
-
+// Map vec_fun across input_data and input_data2, where input_data is
+// a two-dimensional array of size (size, size2), input_data2 is a
+// one-dimensional array of size size2, and input_data2 is broadcast
+// to be of size (size, size2).
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
@@ -377,8 +338,27 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
+  using Vec = vec::Vectorized<scalar_t>;
+  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
+    const scalar_t* input_data_row = input_data + outer_idx * size2;
+    scalar_t* output_data_row = output_data + outer_idx * size2;
+    int64_t inner_idx = 0;
+    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
+      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
+      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx);
+    }
+    if (size2 - inner_idx > 0) {
+      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
+      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
+    }
+  }
 }
 
+
+
 } // namespace vec
 } // namespace executorch
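A usage sketch of the restored `broadcasting_map_2d_by_1d`, with illustrative values; the include path is an assumption, and the call mirrors the one in op_mul.cpp above.

```cpp
#include <executorch/kernels/optimized/vec/functional.h>  // assumed include path

void example() {
  constexpr int64_t size = 2;   // number of rows in the 2-D input
  constexpr int64_t size2 = 4;  // row length, and length of the 1-D input
  float lhs[size * size2] = {1, 2, 3, 4, 5, 6, 7, 8};
  float rhs[size2] = {10, 20, 30, 40};
  float out[size * size2];

  using Vec = executorch::vec::Vectorized<float>;
  executorch::vec::broadcasting_map_2d_by_1d<float>(
      [](Vec x, Vec y) { return x * y; },  // elementwise multiply
      out,
      lhs,
      rhs,
      size,
      size2);
  // out is now {10, 40, 90, 160, 50, 120, 210, 320}: rhs is broadcast
  // across each row of lhs.
}
```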
