
Commit a12c3f7
mul broadcast update
Differential Revision: [D64156862](https://our.internmc.facebook.com/intern/diff/D64156862/)
ghstack-source-id: 248160685
Pull Request resolved: #6239
Parent commit: cd2d2b4

3 files changed (+136, -29 lines)


kernels/optimized/cpu/binary_ops.h

Lines changed: 72 additions & 2 deletions
@@ -41,10 +41,56 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
+  kBroadcastNdByNd,
+  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+
+// Find the single broadcast dimension if it exists.
+int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+
+  // Would like to handle this:
+  // [1, 3, 4, 5]
+  // [2, 3, 4, 5]
+  if (lhs_size != rhs_size) {
+    return 0;
+  }
+
+  int32_t broadcast_dim = 0;
+  // Check
+  // 1. if any dim value is 1 (it constitutes a broadcast dim)
+  // 2. if more than one dim value is 1 (we cannot handle that)
+  // 3. if the non-1 dim values are equal
+  lhs_end--;
+  rhs_end--;
+  while (lhs_end != lhs_begin) {
+    if (*lhs_end == 1 || *rhs_end == 1) {
+      // If more than one broadcast dim is found, return 0.
+      if (broadcast_dim != 0) {
+        return 0;
+      }
+      // A negative index (offset from the end) is used.
+      broadcast_dim = lhs_end - lhs.sizes().end();
+    } else if (*lhs_end != *rhs_end) {
+      // If the non-1 dim values are not equal, return 0.
+      return 0;
+    }
+    lhs_end--;
+    rhs_end--;
+  }
+  return broadcast_dim;
+}
+
+inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
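The scan above walks both shapes from the last dimension toward the front and reports the single broadcast dimension as a negative offset from the end of the sizes array (0 means no single broadcastable dimension was found; dim 0 is never inspected, which is the "[1, 3, 4, 5] vs [2, 3, 4, 5]" case the comment says it would like to handle). A minimal standalone sketch of the same scan over plain shape vectors, assuming equal ranks and no leading 1s (broadcast_dim_of is a hypothetical helper, for illustration only):

#include <cstdint>
#include <vector>

// Hypothetical standalone version of the get_broadcast_dim scan, using
// plain shape vectors instead of Tensor sizes.
int32_t broadcast_dim_of(
    const std::vector<int64_t>& lhs,
    const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) {
    return 0; // Differing ranks are not handled by this path.
  }
  const int32_t n = static_cast<int32_t>(lhs.size());
  int32_t broadcast_dim = 0;
  // Scan from the last dim toward the front; dim 0 is never inspected,
  // matching the kernel's while-loop bound.
  for (int32_t i = n - 1; i >= 1; --i) {
    if (lhs[i] == 1 || rhs[i] == 1) {
      if (broadcast_dim != 0) {
        return 0; // More than one broadcast dim: give up.
      }
      broadcast_dim = i - n; // Negative offset from the end.
    } else if (lhs[i] != rhs[i]) {
      return 0; // Non-1 dims must match exactly.
    }
  }
  return broadcast_dim;
}

// broadcast_dim_of({2, 3, 4, 5}, {2, 1, 4, 5}) == -3
// broadcast_dim_of({2, 3, 4, 5}, {2, 3, 4, 5}) ==  0 (no broadcast dim)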
@@ -63,6 +109,15 @@ inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
+  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
+  // Right now we don't handle last-dim broadcast.
+  if (broadcast_dim < -1) {
+    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) { return x == 1; }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastNdByNd;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
+    }
+  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
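The std::count_if test decides which operand actually carries the size-1 dimension: when exactly one of rhs's dims is 1, rhs is the broadcast input and the arguments are already in the order the kernel expects; otherwise the size-1 dim is in lhs and the caller must swap the operands. Some illustrative shape pairs (assumed examples, not taken from the commit):

// lhs = [2, 3, 4, 5], rhs = [2, 1, 4, 5] -> kBroadcastNdByNd
//   (broadcast_dim = -3; rhs holds the single size-1 dim)
// lhs = [2, 1, 4, 5], rhs = [2, 3, 4, 5] -> kBroadcastNdByNdReverseArguments
//   (the size-1 dim is in lhs, so op_mul swaps lhs and rhs)
// lhs = [2, 3, 4, 1], rhs = [2, 3, 4, 5] -> kNone
//   (broadcast_dim = -1; last-dim broadcast is not handled yet)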
@@ -85,7 +140,22 @@ ElementwiseOptimizedPath inline select_optimized_path(
       internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
+  return internal::select_broadcast_optimized_path(a, b);
+}
+
+std::array<int32_t, 3> inline get_normalized_tensor_size(const Tensor& a, const int32_t broadcast_dim) {
+  ET_CHECK_MSG(a.dim() > broadcast_dim, "Size of tensor: %zd, must be larger than broadcast_dim: %d", a.dim(), broadcast_dim);
+  std::array<int32_t, 3> normalized_tensor_size;
+  normalized_tensor_size[0] = 1;
+  normalized_tensor_size[1] = a.size(broadcast_dim);
+  normalized_tensor_size[2] = 1;
+  for (size_t i = 0; i < broadcast_dim; i++) {
+    normalized_tensor_size[0] = normalized_tensor_size[0] * a.size(i);
+  }
+  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
+    normalized_tensor_size[2] = normalized_tensor_size[2] * a.size(i);
+  }
+  return normalized_tensor_size;
 }
 
 } // namespace executor
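get_normalized_tensor_size collapses an N-d shape into a 3-d (outer, broadcast, inner) view around the broadcast dimension: everything before it multiplies into the outer size, everything after it into the inner size. A small self-contained sketch with an assumed shape (normalize is a hypothetical stand-in for the Tensor-based helper):

#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Collapse dims before/after broadcast_dim into (outer, broadcast, inner).
std::array<int32_t, 3> normalize(
    const std::vector<int32_t>& sizes, int32_t broadcast_dim) {
  std::array<int32_t, 3> n = {1, sizes[broadcast_dim], 1};
  for (int32_t i = 0; i < broadcast_dim; i++) {
    n[0] *= sizes[i];
  }
  for (size_t i = broadcast_dim + 1; i < sizes.size(); i++) {
    n[2] *= sizes[i];
  }
  return n;
}

int main() {
  // [2, 3, 4, 5] with broadcast_dim = 1 -> outer 2, broadcast 3, inner 4 * 5 = 20.
  auto n = normalize({2, 3, 4, 5}, 1);
  assert(n[0] == 2 && n[1] == 3 && n[2] == 20);
  return 0;
}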

kernels/optimized/cpu/op_mul.cpp

Lines changed: 26 additions & 7 deletions
@@ -130,15 +130,17 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+        (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
       lhs = &b;
       rhs = &a;
     } else {
       // Catch failure to update logic when adding new broadcasting possibility.
       ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+          (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd));
       lhs = &a;
       rhs = &b;
     }
@@ -149,15 +151,32 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
+    int64_t outer_size = 1;
+    int64_t broadcast_size;
+    int64_t inner_size;
+    if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+        (selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+      int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
+      auto normalized_tensor_size_lhs = get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+      outer_size = normalized_tensor_size_lhs[0];
+      broadcast_size = normalized_tensor_size_lhs[1];
+      inner_size = normalized_tensor_size_lhs[2];
+    } else {
+      broadcast_size = lhs->sizes()[lhs->dim() - 2];
+      inner_size = lhs->sizes()[lhs->dim() - 1];
+    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
           [](Vec x, Vec y) { return x * y; },
           out.mutable_data_ptr<CTYPE>(),
           lhs->const_data_ptr<CTYPE>(),
           rhs->const_data_ptr<CTYPE>(),
-          lhs->sizes()[lhs->dim() - 2],
-          lhs->sizes()[lhs->dim() - 1]);
+          outer_size,
+          broadcast_size,
+          inner_size);
     });
   } else {
     ScalarType common_type =
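Worked through on assumed shapes: for lhs = [2, 3, 4, 5] and rhs = [2, 1, 4, 5], internal::get_broadcast_dim returns -3, so broadcast_dim_lhs = lhs->dim() + broadcast_dim = 4 + (-3) = 1, and get_normalized_tensor_size(*lhs, 1) yields outer_size = 2, broadcast_size = 3, and inner_size = 4 * 5 = 20. The kernel then consumes rhs as an unsqueezed (2, 1, 20) operand, which matches its physical layout exactly because its broadcast dimension has size 1.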

kernels/optimized/vec/functional_base.h

Lines changed: 38 additions & 20 deletions
@@ -330,6 +330,43 @@ inline void map4(
 // a two-dimensional array of size (size, size2), input_data2 is a
 // one-dimensional array of size size2, and input_data2 is broadcast
 // to be of size (size, size2).
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_3d_and_unsqueezed_3d(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size,
+    int64_t inner_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = inner_size * broadcast_size;
+  int64_t outer_stride_rhs = inner_size;
+  int64_t broadcast_stride_lhs = inner_size;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
+    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
+      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
+      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
+      int64_t inner_idx = 0;
+      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx);
+      }
+      if (inner_size - inner_idx > 0) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
+      }
+    }
+  }
+}
+
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
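Semantically, the vectorized kernel above computes out[o][b][i] = op(lhs[o][b][i], rhs[o][0][i]): lhs and the output have logical shape (outer_size, broadcast_size, inner_size), while rhs is treated as (outer_size, 1, inner_size), its middle dimension "unsqueezed" and reused for every broadcast_idx. A plain scalar reference of that contract (a hypothetical test helper, not part of the commit):

#include <cstdint>

// Scalar reference for broadcasting_map_3d_and_unsqueezed_3d:
// out[o][b][i] = op(lhs[o][b][i], rhs[o][0][i]).
template <typename scalar_t, typename Op>
void broadcasting_map_3d_reference(
    const Op& op,
    scalar_t* out,
    const scalar_t* lhs,
    const scalar_t* rhs,
    int64_t outer_size,
    int64_t broadcast_size,
    int64_t inner_size) {
  for (int64_t o = 0; o < outer_size; ++o) {
    for (int64_t b = 0; b < broadcast_size; ++b) {
      for (int64_t i = 0; i < inner_size; ++i) {
        out[(o * broadcast_size + b) * inner_size + i] = op(
            lhs[(o * broadcast_size + b) * inner_size + i],
            rhs[o * inner_size + i]);
      }
    }
  }
}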
@@ -338,27 +375,8 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  using Vec = vec::Vectorized<scalar_t>;
-  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
-    const scalar_t* input_data_row = input_data + outer_idx * size2;
-    scalar_t* output_data_row = output_data + outer_idx * size2;
-    int64_t inner_idx = 0;
-    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx);
-    }
-    if (size2 - inner_idx > 0) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
-    }
-  }
+  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }
 
-
-
 } // namespace vec
 } // namespace executorch
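With the new kernel in place, broadcasting_map_2d_by_1d becomes the outer_size == 1 special case: a (size, size2) input against a (size2,) input is the same computation as a (1, size, size2) tensor against an unsqueezed (1, 1, size2) one, so the old hand-written loop is replaced by a one-line delegation. A minimal usage sketch of the delegating entry point (assumed data and caller; the multiply lambda relies on Vectorized<float>'s operator*):

#include <vector>
// Assumes the functional_base.h shown above is included.

void example() {
  std::vector<float> a(4 * 8, 2.0f); // input_data, logical shape (4, 8)
  std::vector<float> v(8, 3.0f);     // input_data2, logical shape (8,)
  std::vector<float> out(4 * 8);
  executorch::vec::broadcasting_map_2d_by_1d<float>(
      [](auto x, auto y) { return x * y; }, // operates on Vectorized<float>
      out.data(), a.data(), v.data(), /*size=*/4, /*size2=*/8);
  // Every element of out is now 6.0f.
}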
