
Commit 83f0eb3

Update on "[ET-VK] Enable buffer implementation of aten.linear"
## Changes

As title. Extend the existing buffer implementation of `matmul` to support the linear operator as well.

Differential Revision: [D65277712](https://our.internmc.facebook.com/intern/diff/D65277712/)

[ghstack-poisoned]
2 parents 6f9a236 + 972a62b commit 83f0eb3
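
As context for the change: aten.linear is essentially the matmul that the buffer shader already implements, with the second operand transposed, out = x @ Wᵀ (+ bias). Below is a minimal plain-C++ reference of that relationship; it is an illustrative sketch only, not the Vulkan implementation, and the function name and flat row-major layout are assumptions.

#include <cstddef>
#include <vector>

// Reference semantics of aten.linear on flat row-major buffers:
// out[m][n] = sum_k x[m][k] * W[n][k] + b[n], i.e. a matmul with mat2 transposed.
void linear_reference(
    const std::vector<float>& x,  // M x K input
    const std::vector<float>& W,  // N x K weight, used as if transposed
    const std::vector<float>& b,  // N bias (may be empty)
    std::vector<float>& out,      // M x N output
    size_t M, size_t K, size_t N) {
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float acc = b.empty() ? 0.0f : b[n];
      for (size_t k = 0; k < K; ++k) {
        acc += x[m * K + k] * W[n * K + k];
      }
      out[m * N + n] = acc;
    }
  }
}

This is why the existing buffer matmul path, with its mat2_is_transposed handling (see MatMul.cpp below), can be reused for linear.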

12 files changed: +50, -268 lines

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml

Lines changed: 1 addition & 1 deletion
@@ -19,5 +19,5 @@ image_to_nchw:
   - NAME: image_to_nchw_texture3d
   - NAME: image_to_nchw_texture2d
     STORAGE: texture2d
-  - NAME: image_to_buffer
+  - NAME: clone_image_to_buffer
     TO_STAGING: False

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml

Lines changed: 1 addition & 1 deletion
@@ -19,5 +19,5 @@ nchw_to_image:
   - NAME: nchw_to_image_texture3d
   - NAME: nchw_to_image_texture2d
     STORAGE: texture2d
-  - NAME: buffer_to_image
+  - NAME: clone_buffer_to_image
     FROM_STAGING: False

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 7 additions & 3 deletions
@@ -25,7 +25,11 @@ void resize_clone_node(
   (void)extra_args;
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
   vTensorPtr in = graph->get_tensor(args[1].refs[0]);
-  out->virtual_resize(in->sizes());
+  // TODO: support for when dimensionality doesn't match, i.e. clone is used to
+  // implement squeeze.
+  if (out->dim() == in->dim()) {
+    out->virtual_resize(in->sizes());
+  }
 }
 
 void add_clone_node(
@@ -56,7 +60,7 @@ void add_image_to_buffer_node(
     ComputeGraph& graph,
     const ValueRef image,
     const ValueRef buffer) {
-  std::string kernel_name = "image_to_buffer";
+  std::string kernel_name = "clone_image_to_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(image));
   vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
 
@@ -80,7 +84,7 @@ void add_buffer_to_image_node(
     ComputeGraph& graph,
     const ValueRef buffer,
     const ValueRef image) {
-  std::string kernel_name = "buffer_to_image";
+  std::string kernel_name = "clone_buffer_to_image";
   add_dtype_suffix(kernel_name, graph.dtype_of(image));
   vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
 

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 4 additions & 5 deletions
@@ -77,11 +77,10 @@ void add_matmul_naive_buffer_node(
       graph.size_at<uint32_t>(-2, out),
       graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};
 
-  int mat2_is_transposed_val = 0;
-  if (mat2_is_transposed != kDummyValueRef &&
-      graph.get_bool(mat2_is_transposed)) {
-    mat2_is_transposed_val = 1;
-  }
+  int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
+                                graph.get_bool(mat2_is_transposed))
+      ? 1
+      : 0;
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 1 addition & 1 deletion
@@ -56,8 +56,8 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader(
   if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer &&
       !int8_buffer_enabled) {
     kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
-    add_dtype_suffix(kernel_name, v_src);
     add_storage_type_suffix(kernel_name, v_src);
+    add_dtype_suffix(kernel_name, v_src);
     return VK_KERNEL_FROM_STR(kernel_name);
   }
 

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 1 addition & 1 deletion
@@ -118,8 +118,8 @@ void record_bitw8_image_to_nchw_nobitw8buffer_op(
   utils::uvec3 global_wg_size = {buffer_len, 1, 1};
 
   std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer";
-  add_dtype_suffix(kernel_name, v_src);
   add_storage_type_suffix(kernel_name, v_src);
+  add_dtype_suffix(kernel_name, v_src);
 
   context->submit_compute_job(
       VK_KERNEL_FROM_STR(kernel_name),

extension/llm/custom_ops/targets.bzl

Lines changed: 1 addition & 6 deletions
@@ -1,9 +1,4 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
-load(
-    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
-    "get_compiler_optimization_flags",
-)
-
 
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.
@@ -39,7 +34,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:reduce_util",
             "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
         ],
-        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(),
+        compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
        visibility = [
            "//executorch/...",
            "//executorch/extension/llm/custom_ops/...",

kernels/optimized/cpu/binary_ops.h

Lines changed: 2 additions & 86 deletions
@@ -41,62 +41,10 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
-  kBroadcastNdByNd,
-  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-
-// Find the single broadcast dimension if it exists.
-// This path aims to handle broadcast of the following form
-// A = [a1, a2,., 1, .., an]
-// B = [b1, b2,., bm, .., bn]
-// OR
-// A = [a1, a2,., am, .., an]
-// B = [b1, b2,., 1, .., bn]
-int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
-  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
-  auto lhs_end = lhs.sizes().end();
-
-  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
-  auto rhs_end = rhs.sizes().end();
-
-  const auto lhs_size = lhs_end - lhs_begin;
-  const auto rhs_size = rhs_end - rhs_begin;
-
-  // Following example is not handled at the moment
-  // [1, 3, 4, 5]
-  // [2, 3, 4, 5]
-  if (lhs_size != rhs_size) {
-    return 0;
-  }
-
-  int32_t broadcast_dim = 0;
-  // Check
-  // 1. if any dim value is 1 (it constitutes a broadcast dim)
-  // 2. If more than one dim value is 1 (we cannot handle)
-  // 3. If non-1 dim values are equal
-  lhs_end--;
-  rhs_end--;
-  while (lhs_end != lhs_begin) {
-    if (*lhs_end == 1 || *rhs_end == 1) {
-      // If more than one broadcast dim is found, return 0.
-      if (broadcast_dim != 0) {
-        return 0;
-      }
-      // negative index is used
-      broadcast_dim = lhs_end - lhs.sizes().end();
-    } else if (*lhs_end != *rhs_end) {
-      // If non-1 dim values are not equal, return 0.
-      return 0;
-    }
-    lhs_end--;
-    rhs_end--;
-  }
-  return broadcast_dim;
-}
-
-inline ElementwiseOptimizedPath select_broadcast_optimized_path(
+inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -115,17 +63,6 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
-  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
-  // Right now we dont handle last dim broadcast
-  if (broadcast_dim < -1) {
-    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
-          return x == 1;
-        }) == 1) {
-      return ElementwiseOptimizedPath::kBroadcastNdByNd;
-    } else {
-      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
-    }
-  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -148,28 +85,7 @@ ElementwiseOptimizedPath inline select_optimized_path(
        internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_optimized_path(a, b);
-}
-
-std::array<int32_t, 3> inline get_normalized_tensor_size(
-    const Tensor& a,
-    const int32_t broadcast_dim) {
-  ET_CHECK_MSG(
-      a.dim() > broadcast_dim,
-      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
-      a.dim(),
-      broadcast_dim);
-  std::array<int32_t, 3> normalized_tensor_size;
-  normalized_tensor_size[0] = 1;
-  normalized_tensor_size[1] = a.size(broadcast_dim);
-  normalized_tensor_size[2] = 1;
-  for (size_t i = 0; i < broadcast_dim; i++) {
-    normalized_tensor_size[0] *= a.size(i);
-  }
-  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
-    normalized_tensor_size[2] *= a.size(i);
-  }
-  return normalized_tensor_size;
+  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
 }
 
 } // namespace executor

kernels/optimized/cpu/op_mul.cpp

Lines changed: 7 additions & 30 deletions
@@ -130,19 +130,15 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    if (selected_optimized_path ==
+        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
       lhs = &b;
       rhs = &a;
     } else {
       // Catch failure to update logic when adding new broadcasting possibility.
       ET_DCHECK(
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcastNdByNd));
+          selected_optimized_path ==
+          ElementwiseOptimizedPath::kBroadcast2dBy1d);
       lhs = &a;
       rhs = &b;
     }
@@ -153,34 +149,15 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
-    int64_t outer_size = 1;
-    int64_t broadcast_size;
-    int64_t inner_size;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-      auto normalized_tensor_size_lhs =
-          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-      outer_size = normalized_tensor_size_lhs[0];
-      broadcast_size = normalized_tensor_size_lhs[1];
-      inner_size = normalized_tensor_size_lhs[2];
-    } else {
-      broadcast_size = lhs->sizes()[lhs->dim() - 2];
-      inner_size = lhs->sizes()[lhs->dim() - 1];
-    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
+      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
          [](Vec x, Vec y) { return x * y; },
          out.mutable_data_ptr<CTYPE>(),
          lhs->const_data_ptr<CTYPE>(),
          rhs->const_data_ptr<CTYPE>(),
-          outer_size,
-          broadcast_size,
-          inner_size);
+          lhs->sizes()[lhs->dim() - 2],
+          lhs->sizes()[lhs->dim() - 1]);
     });
   } else {
     ScalarType common_type =

kernels/optimized/vec/functional_base.h

Lines changed: 24 additions & 44 deletions
@@ -326,49 +326,10 @@ inline void map4(
 }
 
 
-// This function implements broadcasting binary operation on two tensors
-// where lhs tensor is treated to be of shape [outer_size, broadcast_size, inner_size]
-// and rhs tensor is treated to be of shape [outer_size, 1, inner_size]
-// And this 1st dimension is considered broadcasting dimension
-// This formula can map broadcasting on any dim=broadcast_dim
-// for any two N dimensional tensors, where 0 < braodcast_dim < N-1
-template <typename scalar_t, typename Op>
-inline void broadcasting_map_3d_and_unsqueezed_3d(
-    const Op& vec_fun,
-    scalar_t* output_data,
-    const scalar_t* lhs,
-    const scalar_t* rhs,
-    int64_t outer_size,
-    int64_t broadcast_size,
-    int64_t inner_size) {
-  using Vec = vec::Vectorized<scalar_t>;
-  int64_t outer_stride_lhs = inner_size * broadcast_size;
-  int64_t outer_stride_rhs = inner_size;
-  int64_t broadcast_stride_lhs = inner_size;
-  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
-    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
-    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
-    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
-    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
-      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
-      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
-      int64_t inner_idx = 0;
-      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
-        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
-        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
-        Vec output_vec = vec_fun(data_vec, data_vec2);
-        output_vec.store(output_data_row_2 + inner_idx);
-      }
-      if (inner_size - inner_idx > 0) {
-        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
-        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
-        Vec output_vec = vec_fun(data_vec, data_vec2);
-        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
-      }
-    }
-  }
-}
-
+// Map vec_fun across input_data and input_data2, where input_data is
+// a two-dimensional array of size (size, size2), input_data2 is a
+// one-dimensional array of size size2, and input_data2 is broadcast
+// to be of size (size, size2).
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
@@ -377,8 +338,27 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
+  using Vec = vec::Vectorized<scalar_t>;
+  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
+    const scalar_t* input_data_row = input_data + outer_idx * size2;
+    scalar_t* output_data_row = output_data + outer_idx * size2;
+    int64_t inner_idx = 0;
+    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
+      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
+      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx);
+    }
+    if (size2 - inner_idx > 0) {
+      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
+      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
+      Vec output_vec = vec_fun(data_vec, data_vec2);
+      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
+    }
+  }
 }
 
+
+
 } // namespace vec
 } // namespace executorch
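
For reference, the restored broadcasting_map_2d_by_1d combines each row of a (size, size2) input with a single length-size2 vector. A scalar sketch of the same computation, without the Vectorized fast path used above (illustrative only; this helper name is hypothetical and not part of the library):

#include <cstdint>

// Scalar equivalent of broadcasting_map_2d_by_1d: out[i][j] = op(a[i][j], v[j]).
template <typename scalar_t, typename Op>
void broadcasting_map_2d_by_1d_scalar(
    const Op& op,
    scalar_t* output_data,
    const scalar_t* input_data,   // size x size2, row-major
    const scalar_t* input_data2,  // length size2, broadcast across rows
    int64_t size,
    int64_t size2) {
  for (int64_t i = 0; i < size; ++i) {
    for (int64_t j = 0; j < size2; ++j) {
      output_data[i * size2 + j] = op(input_data[i * size2 + j], input_data2[j]);
    }
  }
}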
