Commit d2cd73d
Update base for Update on "[ET-VK] Allow clone op to transfer between memory layouts and storage types"

## Changes

As title. Extend the functionality of the `aten.clone` operator to allow transitioning the storage type and memory layout from the input tensor to the output tensor.

## Context

This functionality will be used to transition input tensors to the optimal storage type and memory layout before an op executes. The transition nodes will be added by a memory metadata tagging pass that will be introduced in a subsequent diff.

Differential Revision: [D65277710](https://our.internmc.facebook.com/intern/diff/D65277710/)

[ghstack-poisoned]
1 parent 77fe041 commit d2cd73d

File tree: 6 files changed, +35 −256 lines changed

extension/llm/custom_ops/targets.bzl
Lines changed: 1 addition & 6 deletions

@@ -1,9 +1,4 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
-load(
-    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
-    "get_compiler_optimization_flags",
-)
-
 
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.
@@ -39,7 +34,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:reduce_util",
             "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
         ],
-        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(),
+        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
         visibility = [
             "//executorch/...",
             "//executorch/extension/llm/custom_ops/...",

kernels/optimized/cpu/binary_ops.h
Lines changed: 2 additions & 86 deletions

@@ -41,62 +41,10 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
-  kBroadcastNdByNd,
-  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-
-// Find the single broadcast dimension if it exists.
-// This path aims to handle broadcast of the following form
-// A = [a1, a2,., 1, .., an]
-// B = [b1, b2,., bm, .., bn]
-// OR
-// A = [a1, a2,., am, .., an]
-// B = [b1, b2,., 1, .., bn]
-int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
-  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
-  auto lhs_end = lhs.sizes().end();
-
-  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
-  auto rhs_end = rhs.sizes().end();
-
-  const auto lhs_size = lhs_end - lhs_begin;
-  const auto rhs_size = rhs_end - rhs_begin;
-
-  // Following example is not handled at the moment
-  // [1, 3, 4, 5]
-  // [2, 3, 4, 5]
-  if (lhs_size != rhs_size) {
-    return 0;
-  }
-
-  int32_t broadcast_dim = 0;
-  // Check
-  // 1. if any dim value is 1 (it constitutes a broadcast dim)
-  // 2. If more than one dim value is 1 (we cannot handle)
-  // 3. If non-1 dim values are equal
-  lhs_end--;
-  rhs_end--;
-  while (lhs_end != lhs_begin) {
-    if (*lhs_end == 1 || *rhs_end == 1) {
-      // If more than one broadcast dim is found, return 0.
-      if (broadcast_dim != 0) {
-        return 0;
-      }
-      // negative index is used
-      broadcast_dim = lhs_end - lhs.sizes().end();
-    } else if (*lhs_end != *rhs_end) {
-      // If non-1 dim values are not equal, return 0.
-      return 0;
-    }
-    lhs_end--;
-    rhs_end--;
-  }
-  return broadcast_dim;
-}
-
-inline ElementwiseOptimizedPath select_broadcast_optimized_path(
+inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -115,17 +63,6 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
-  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
-  // Right now we dont handle last dim broadcast
-  if (broadcast_dim < -1) {
-    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
-          return x == 1;
-        }) == 1) {
-      return ElementwiseOptimizedPath::kBroadcastNdByNd;
-    } else {
-      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
-    }
-  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -148,28 +85,7 @@ ElementwiseOptimizedPath inline select_optimized_path(
       internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_optimized_path(a, b);
-}
-
-std::array<int32_t, 3> inline get_normalized_tensor_size(
-    const Tensor& a,
-    const int32_t broadcast_dim) {
-  ET_CHECK_MSG(
-      a.dim() > broadcast_dim,
-      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
-      a.dim(),
-      broadcast_dim);
-  std::array<int32_t, 3> normalized_tensor_size;
-  normalized_tensor_size[0] = 1;
-  normalized_tensor_size[1] = a.size(broadcast_dim);
-  normalized_tensor_size[2] = 1;
-  for (size_t i = 0; i < broadcast_dim; i++) {
-    normalized_tensor_size[0] *= a.size(i);
-  }
-  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
-    normalized_tensor_size[2] *= a.size(i);
-  }
-  return normalized_tensor_size;
+  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
 }
 
 } // namespace executor
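For readers following the revert: the deleted `get_broadcast_dim()` scanned the size lists from the last dimension toward the front and returned the single broadcast dimension as a negative index (-1 is the last dim), or 0 when none qualifies; the caller then only took the Nd path for `broadcast_dim < -1`. Below is a minimal standalone sketch of that detection logic against plain size vectors. `find_single_broadcast_dim` is a hypothetical name, not part of the codebase, and the leading-1 stripping done by `arrayref_begin_ignoring_leading_1s` is assumed to have already happened.

#include <cstdint>
#include <vector>

// Hypothetical sketch of the removed broadcast-dim detection.
// Returns the broadcast dim as a negative index, or 0 if there is no
// single usable broadcast dim.
int32_t find_single_broadcast_dim(
    const std::vector<int64_t>& lhs,
    const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) {
    // Mixed effective ranks were not handled, e.g. [1, 3, 4, 5] vs [2, 3, 4, 5].
    return 0;
  }
  const int64_t n = static_cast<int64_t>(lhs.size());
  int32_t broadcast_dim = 0;
  // Scan from the last dim toward the front, stopping before dim 0,
  // mirroring the removed while-loop over reversed iterators.
  for (int64_t i = n - 1; i >= 1; --i) {
    if (lhs[i] == 1 || rhs[i] == 1) {
      if (broadcast_dim != 0) {
        return 0;  // More than one broadcast dim cannot be handled.
      }
      broadcast_dim = static_cast<int32_t>(i - n);  // Negative index.
    } else if (lhs[i] != rhs[i]) {
      return 0;  // Non-1 dims must match exactly.
    }
  }
  return broadcast_dim;
}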

kernels/optimized/cpu/op_mul.cpp
Lines changed: 7 additions & 30 deletions

@@ -130,19 +130,15 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    if (selected_optimized_path ==
+        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
       lhs = &b;
       rhs = &a;
     } else {
       // Catch failure to update logic when adding new broadcasting possibility.
       ET_DCHECK(
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcastNdByNd));
+          selected_optimized_path ==
+          ElementwiseOptimizedPath::kBroadcast2dBy1d);
       lhs = &a;
       rhs = &b;
     }
@@ -153,34 +149,15 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
-    int64_t outer_size = 1;
-    int64_t broadcast_size;
-    int64_t inner_size;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-      auto normalized_tensor_size_lhs =
-          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-      outer_size = normalized_tensor_size_lhs[0];
-      broadcast_size = normalized_tensor_size_lhs[1];
-      inner_size = normalized_tensor_size_lhs[2];
-    } else {
-      broadcast_size = lhs->sizes()[lhs->dim() - 2];
-      inner_size = lhs->sizes()[lhs->dim() - 1];
-    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
+      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
          [](Vec x, Vec y) { return x * y; },
          out.mutable_data_ptr<CTYPE>(),
          lhs->const_data_ptr<CTYPE>(),
          rhs->const_data_ptr<CTYPE>(),
-          outer_size,
-          broadcast_size,
-          inner_size);
+          lhs->sizes()[lhs->dim() - 2],
+          lhs->sizes()[lhs->dim() - 1]);
     });
   } else {
     ScalarType common_type =
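With the Nd-by-Nd machinery reverted, the remaining broadcast branch only handles a 2-d lhs times a 1-d rhs (swapping `a` and `b` first on the ReverseArguments path). A plain-scalar reference of what that branch computes, handy for checking the vectorized kernel against, might look like the sketch below; `mul_2d_by_1d_reference` is a hypothetical helper, not code from this commit.

#include <cstdint>

// Hypothetical scalar reference for the 2d-by-1d broadcast multiply:
// lhs is viewed as [rows, cols] and rhs (length cols) is broadcast
// across every row, matching the sizes passed above:
// rows = lhs->sizes()[lhs->dim() - 2], cols = lhs->sizes()[lhs->dim() - 1].
void mul_2d_by_1d_reference(
    float* out,
    const float* lhs,
    const float* rhs,
    int64_t rows,
    int64_t cols) {
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      out[r * cols + c] = lhs[r * cols + c] * rhs[c];
    }
  }
}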

kernels/optimized/vec/functional_base.h
Lines changed: 24 additions & 44 deletions

@@ -326,49 +326,10 @@ inline void map4(
 }
 
 
-// This function implements broadcasting binary operation on two tensors
-// where lhs tensor is treated to be of shape [outer_size, broadcast_size, inner_size]
-// and rhs tensor is treated to be of shape [outer_size, 1, inner_size]
-// And this 1st dimension is considered broadcasting dimension
-// This formula can map broadcasting on any dim=broadcast_dim
-// for any two N dimensional tensors, where 0 < braodcast_dim < N-1
-template <typename scalar_t, typename Op>
-inline void broadcasting_map_3d_and_unsqueezed_3d(
-    const Op& vec_fun,
-    scalar_t* output_data,
-    const scalar_t* lhs,
-    const scalar_t* rhs,
-    int64_t outer_size,
-    int64_t broadcast_size,
-    int64_t inner_size) {
-  using Vec = vec::Vectorized<scalar_t>;
-  int64_t outer_stride_lhs = inner_size * broadcast_size;
-  int64_t outer_stride_rhs = inner_size;
-  int64_t broadcast_stride_lhs = inner_size;
-  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
-    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
-    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
-    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
-    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
-      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
-      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
-      int64_t inner_idx = 0;
-      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
-        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
-        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
-        Vec output_vec = vec_fun(data_vec, data_vec2);
-        output_vec.store(output_data_row_2 + inner_idx);
-      }
-      if (inner_size - inner_idx > 0) {
-        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
-        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
-        Vec output_vec = vec_fun(data_vec, data_vec2);
-        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
-      }
-    }
-  }
-}
-
+// Map vec_fun across input_data and input_data2, where input_data is
+// a two-dimensional array of size (size, size2), input_data2 is a
+// one-dimensional array of size size2, and input_data2 is broadcast
+// to be of size (size, size2).
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
@@ -377,8 +338,27 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
+  using Vec = vec::Vectorized<scalar_t>;
+  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
+    const scalar_t* input_data_row = input_data + outer_idx * size2;
+    scalar_t* output_data_row = output_data + outer_idx * size2;
+    int64_t inner_idx = 0;
+    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
+      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
+      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx);
+    }
+    if (size2 - inner_idx > 0) {
+      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
+      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
+    }
+  }
 }
 
+
+
 } // namespace vec
 } // namespace executorch
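A minimal usage sketch of the restored `broadcasting_map_2d_by_1d`, matching how op_mul.cpp calls it: the include path is an assumption, and the sizes are deliberately small so the remainder (tail) load/store path is likely exercised on typical vector widths.

#include <executorch/kernels/optimized/vec/functional_base.h>  // assumed path

// Multiply each row of a 2x4 row-major lhs by a length-4 rhs that is
// broadcast across both rows.
void example() {
  using Vec = executorch::vec::Vectorized<float>;
  float lhs[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float rhs[4] = {10, 20, 30, 40};
  float out[8] = {};
  executorch::vec::broadcasting_map_2d_by_1d<float>(
      [](Vec x, Vec y) { return x * y; },
      out,
      lhs,
      rhs,
      /*size=*/2,
      /*size2=*/4);
  // out now holds {10, 40, 90, 160, 50, 120, 210, 320}.
}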

kernels/test/op_mul_test.cpp
Lines changed: 0 additions & 89 deletions

@@ -153,73 +153,6 @@ class OpMulOutTest : public OperatorTest {
     }
   }
 
-  template <ScalarType DTYPE>
-  void test_broadcast_3D() {
-    TensorFactory<DTYPE> tf_a;
-
-    Tensor a =
-        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
-
-    // Destination for output of mul.
-    Tensor out =
-        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-    Tensor expected = tf_a.make(
-        {2, 2, 3}, /*data=*/{2, 6, 12, 8, 15, 24, 35, 48, 63, 50, 66, 84});
-
-    // Check that it matches the expected output.
-    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
-    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
-  }
-
-  template <ScalarType DTYPE>
-  void test_broadcast_4D() {
-    TensorFactory<DTYPE> tf_a;
-
-    Tensor a = tf_a.make(
-        {2, 2, 3, 5},
-        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
-                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
-                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
-                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
-    Tensor b = tf_a.make(
-        {2, 1, 3, 5},
-        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
-                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
-
-    // Destination for output of mul.
-    Tensor out = tf_a.zeros({2, 2, 3, 5});
-    Tensor expected = tf_a.make(
-        {2, 2, 3, 5},
-        /*data=*/{1, 4, 9, 16, 25, 36, 49, 64, 81, 100,
-                  121, 144, 169, 196, 225, 16, 34, 54, 76, 100,
-                  126, 154, 184, 216, 250, 286, 324, 364, 406, 450,
-                  496, 544, 594, 646, 700, 756, 814, 874, 936, 1000,
-                  1066, 1134, 1204, 1276, 1350, 736, 799, 864, 931, 1000,
-                  1071, 1144, 1219, 1296, 1375, 1456, 1539, 1624, 1711, 1800});
-
-    // Check that it matches the expected output.
-    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
-    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
-
-    b = tf_a.make(
-        {2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
-                                11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
-    out = tf_a.zeros({2, 2, 3, 5});
-    expected = tf_a.make(
-        {2, 2, 3, 5},
-        /*data=*/{1, 4, 9, 16, 25, 6, 14, 24, 36, 50,
-                  11, 24, 39, 56, 75, 96, 119, 144, 171, 200,
-                  126, 154, 184, 216, 250, 156, 189, 224, 261, 300,
-                  341, 384, 429, 476, 525, 396, 444, 494, 546, 600,
-                  451, 504, 559, 616, 675, 736, 799, 864, 931, 1000,
-                  816, 884, 954, 1026, 1100, 896, 969, 1044, 1121, 1200});
-
-    // Check that it matches the expected output.
-    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
-    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
-  }
-
   template <ScalarType DTYPE>
   void test_broadcast_b2a() {
     TensorFactory<DTYPE> tf_a;
@@ -363,16 +296,6 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) {
   test_broadcast_a2b<ScalarType::Int>();
   test_broadcast_a2b<ScalarType::Half>();
   test_broadcast_a2b<ScalarType::BFloat16>();
-
-  // Test 3D tensors
-  test_broadcast_3D<ScalarType::Float>();
-  test_broadcast_3D<ScalarType::Half>();
-  test_broadcast_3D<ScalarType::BFloat16>();
-
-  // Test 4D tensors
-  test_broadcast_4D<ScalarType::Float>();
-  test_broadcast_4D<ScalarType::Half>();
-  test_broadcast_4D<ScalarType::BFloat16>();
 }
 
 // Broadcast tensor a's size to tensor b's size
@@ -382,18 +305,6 @@ TEST_F(OpMulOutTest, BroadcastB2ATest) {
   test_broadcast_b2a<ScalarType::BFloat16>();
 }
 
-TEST_F(OpMulOutTest, BroadcastNDTest) {
-  // Test 3D tensors
-  test_broadcast_3D<ScalarType::Float>();
-  test_broadcast_3D<ScalarType::Half>();
-  test_broadcast_3D<ScalarType::BFloat16>();
-
-  // Test 4D tensors
-  test_broadcast_4D<ScalarType::Float>();
-  test_broadcast_4D<ScalarType::Half>();
-  test_broadcast_4D<ScalarType::BFloat16>();
-}
-
 // Broadcast tensor a and b's size to a new size c.
 TEST_F(OpMulOutTest, BroadcastAB2CTest) {
   TensorFactory<ScalarType::Int> tf_a;
shim/xplat/executorch/kernels/portable/op_registration_util.bzl
Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@ def define_op_library(name, deps, android_deps, aten_target, _allow_third_party_
             # library, and it blocks users like unit tests to use kernel
             # implementation directly. So we enable this for xplat only.
             ["-fvisibility=hidden"] if is_xplat() else []
-        ) + get_compiler_optimization_flags(),
+        ),
         deps = [
            "//executorch/runtime/kernel:kernel_includes" + aten_suffix,
        ] + deps,
