
Commit 7036fd9

Handle broadcast semantics for last dim

Differential Revision: [D64156863](https://our.internmc.facebook.com/intern/diff/D64156863/)
ghstack-source-id: 248160686
Pull Request resolved: #6240

1 parent: a12c3f7

4 files changed: +151 additions, -50 deletions

kernels/optimized/cpu/binary_ops.h

Lines changed: 8 additions & 0 deletions
```diff
@@ -43,6 +43,8 @@ enum class ElementwiseOptimizedPath {
   kBroadcast2dBy1dReverseArguments,
   kBroadcastNdByNd,
   kBroadcastNdByNdReverseArguments,
+  kBroadcastLastDim,
+  kBroadcastLastDimReverseArguments,
 };
 
 namespace internal {
@@ -117,6 +119,12 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     } else {
       return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
     }
+  } else if (broadcast_dim == -1) {
+    if (std::count_if(lhs_begin, lhs_end, [](Tensor::SizesType x) { return x == 1; }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastLastDim;
+    }
   }
   return ElementwiseOptimizedPath::kNone;
 }
```
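
For intuition: the new branch fires when `broadcast_dim == -1`, i.e. the two shapes differ only in their last dimension, and the `count_if` check decides which argument carries the size-1 dimension (and therefore whether the arguments must be swapped before the kernel runs). Below is a minimal standalone sketch of just that rule, using plain `std::vector<int>` shapes instead of `Tensor::SizesType`; the enum and helper names here are illustrative stand-ins, not the library API, and the real function performs additional shape checks first.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

enum class Path { BroadcastLastDim, BroadcastLastDimReverseArguments };

// Stand-in for the broadcast_dim == -1 branch above: if the first argument's
// shape holds exactly one size-1 entry, it is the broadcast side, so the
// kernel must run with the arguments reversed.
Path select_last_dim_path(const std::vector<int>& lhs_sizes) {
  if (std::count_if(
          lhs_sizes.begin(), lhs_sizes.end(), [](int x) { return x == 1; }) ==
      1) {
    return Path::BroadcastLastDimReverseArguments;
  }
  return Path::BroadcastLastDim;
}

int main() {
  // a: [4, 5] times b: [4, 1] -> b is broadcast along the last dim.
  assert(select_last_dim_path({4, 5}) == Path::BroadcastLastDim);
  // a: [4, 1] times b: [4, 5] -> same kernel, arguments swapped.
  assert(select_last_dim_path({4, 1}) == Path::BroadcastLastDimReverseArguments);
  return 0;
}
```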

kernels/optimized/cpu/op_mul.cpp

Lines changed: 113 additions & 50 deletions
```diff
@@ -11,6 +11,7 @@
 #include <executorch/kernels/optimized/vec/vec.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
@@ -66,6 +67,117 @@ template <
     typename CTYPE_OUT>
 struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
     : public ReportCanCastBug {};
+
+Tensor& handle_last_dim_broadcast(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if (selected_optimized_path ==
+      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
+  const auto broadcast_size = out.size(out.dim() - 1);
+  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
+        [](Vec x, Vec y) { return x * y; },
+        out.mutable_data_ptr<CTYPE>(),
+        lhs->const_data_ptr<CTYPE>(),
+        rhs->const_data_ptr<CTYPE>(),
+        outer_size,
+        broadcast_size);
+  });
+  return out;
+}
+
+Tensor& handle_broadcast_mul(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDim) ||
+      (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
+    return handle_last_dim_broadcast(
+        ctx, a, b, out, selected_optimized_path);
+  }
+
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    // Catch failure to update logic when adding new broadcasting possibility.
+    ET_DCHECK(
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd));
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  int64_t outer_size = 1;
+  int64_t broadcast_size;
+  int64_t inner_size;
+  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+    int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
+    auto normalized_tensor_size_lhs =
+        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+    outer_size = normalized_tensor_size_lhs[0];
+    broadcast_size = normalized_tensor_size_lhs[1];
+    inner_size = normalized_tensor_size_lhs[2];
+  } else {
+    broadcast_size = lhs->sizes()[lhs->dim() - 2];
+    inner_size = lhs->sizes()[lhs->dim() - 1];
+  }
+  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
+        [](Vec x, Vec y) { return x * y; },
+        out.mutable_data_ptr<CTYPE>(),
+        lhs->const_data_ptr<CTYPE>(),
+        rhs->const_data_ptr<CTYPE>(),
+        outer_size,
+        broadcast_size,
+        inner_size);
+  });
+  return out;
+}
 } // namespace
 
 Tensor& opt_mul_out(
@@ -128,56 +240,7 @@ Tensor& opt_mul_out(
         out.numel());
   });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-        (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      lhs = &b;
-      rhs = &a;
-    } else {
-      // Catch failure to update logic when adding new broadcasting possibility.
-      ET_DCHECK(
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-          (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd));
-      lhs = &a;
-      rhs = &b;
-    }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    int64_t outer_size = 1;
-    int64_t broadcast_size;
-    int64_t inner_size;
-    if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-        (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-      int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
-      auto normalized_tensor_size_lhs = get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-      outer_size = normalized_tensor_size_lhs[0];
-      broadcast_size = normalized_tensor_size_lhs[1];
-      inner_size = normalized_tensor_size_lhs[2];
-    } else {
-      broadcast_size = lhs->sizes()[lhs->dim() - 2];
-      inner_size = lhs->sizes()[lhs->dim() - 1];
-    }
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
-          [](Vec x, Vec y) { return x * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          lhs->const_data_ptr<CTYPE>(),
-          rhs->const_data_ptr<CTYPE>(),
-          outer_size,
-          broadcast_size,
-          inner_size);
-    });
+    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
```
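
Semantically, the new path computes an elementwise product where the rhs contributes one scalar per row: `out[i][j] = lhs[i][j] * rhs[i]`. The following scalar (non-vectorized) reference shows what `handle_last_dim_broadcast` produces; it is a sketch assuming contiguous buffers where lhs has shape `[outer_size, broadcast_size]` and rhs has shape `[outer_size, 1]`, and the function name is hypothetical.

```cpp
#include <cstdint>

// Hypothetical scalar reference for the last-dim broadcast multiply:
// out[i][j] = lhs[i][j] * rhs[i]. The vectorized helper added to
// functional_base.h computes the same values, Vec::size() lanes at a time.
void mul_broadcast_last_dim_ref(
    float* out,
    const float* lhs,
    const float* rhs,
    int64_t outer_size,
    int64_t broadcast_size) {
  for (int64_t i = 0; i < outer_size; ++i) {
    for (int64_t j = 0; j < broadcast_size; ++j) {
      out[i * broadcast_size + j] = lhs[i * broadcast_size + j] * rhs[i];
    }
  }
}
```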

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
```diff
@@ -72,6 +72,7 @@ _OPTIMIZED_ATEN_OPS = (
             ":binary_ops",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
         ],
     ),
     op_target(
```

kernels/optimized/vec/functional_base.h

Lines changed: 29 additions & 0 deletions
```diff
@@ -378,5 +378,34 @@ inline void broadcasting_map_2d_by_1d(
   broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }
 
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_broadcast_last_dim(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = broadcast_size;
+  int64_t outer_stride_rhs = 1;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    int64_t inner_idx = 0;
+    Vec data_vec2 = Vec(rhs[outer_idx]);
+    for (; inner_idx < broadcast_size - (broadcast_size % Vec::size()); inner_idx += Vec::size()) {
+      Vec data_vec = Vec::loadu(lhs_outer + inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx);
+    }
+    if (broadcast_size - inner_idx > 0) {
+      Vec data_vec = Vec::loadu(lhs_outer + inner_idx, broadcast_size - inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx, broadcast_size - inner_idx);
+    }
+  }
+}
+
 } // namespace vec
 } // namespace executorch
```
