Commit 335056c

Update base for Update on "[ET-VK] Allow clone op to transfer between memory layouts and storage types"
## Changes

As title. Extend the functionality of the `aten.clone` operator to allow transitioning the storage type and memory layout between the input and the output tensor.

## Context

This functionality will be used to transition input tensors to the optimal storage type and memory layout before entering the execution of an op. The transition nodes will be added by a memory metadata tagging pass that will be introduced in a subsequent diff.

Differential Revision: [D65277710](https://our.internmc.facebook.com/intern/diff/D65277710/)

[ghstack-poisoned]
2 parents d2cd73d + 1972e69 commit 335056c

8 files changed (+291, -65 lines)

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 10 additions & 14 deletions
@@ -91,27 +91,23 @@ void main() {
 
 #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
 
-VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
-  u16vec3 mat1_pos = u16vec3(0, out_pos.yz);
-  u16vec3 qmat2_pos = u16vec3(0, out_pos.x * 4, 0);
+VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) {
+  const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4);
 
   VEC4_T outtex = VEC4_T(0);
 
   const u16vec3 scales_pos = u16vec3(out_pos.x, 0, 0);
   const VEC4_T scales = load_texel(t_scales, scales_pos);
 
-  for (int i = 0; i < K; i += 4) {
-    const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
+  for (uint16_t i = uint16_t(0), x = uint16_t(0); i < K; i += uint16_t(4), x++) {
+    const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.yz));
     const VEC4_T sums = VEC4_T(
-        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos)),
-        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 1, 0))),
-        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 2, 0))),
-        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 3, 0))));
+        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))),
+        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))),
+        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2), 0))),
+        dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3), 0))));
 
     outtex += sums;
-
-    mat1_pos.x++;
-    qmat2_pos.x++;
   }
 
   outtex *= scales;
@@ -120,12 +116,12 @@ VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
 }
 
 void main() {
-  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
+  const u16vec3 out_pos = u16vec3(gl_GlobalInvocationID);
   if (any(greaterThanEqual(out_pos, out_limits))) {
     return;
   }
 
-  VEC4_T outtex = q_8w_linear(out_pos, mat1_sizes.x);
+  VEC4_T outtex = q_8w_linear(out_pos, uint16_t(mat1_sizes.x));
   write_texel(t_out, out_pos, outtex);
 }

backends/vulkan/targets.bzl

Lines changed: 25 additions & 16 deletions
@@ -101,28 +101,37 @@ def define_common_targets(is_fbcode = False):
         "fbsource//third-party/VulkanMemoryAllocator/3.0.1:VulkanMemoryAllocator_xplat",
     ]
 
-    if not is_fbcode:
+    if is_fbcode:
         VK_API_DEPS += [
-            "fbsource//third-party/volk:volk",
+            "fbsource//third-party/swiftshader:swiftshader_vk_headers",
+            "fbsource//third-party/swiftshader/lib/linux-x64:libvk_swiftshader_fbcode",
+            "fbsource//third-party/swiftshader/lib/linux-x64:libvk_swiftshader_so",
         ]
+    else:
         VK_API_DEPS += select({
-            "DEFAULT": [],
-            "ovr_config//os:android": ["fbsource//third-party/toolchains:android"],
+            "DEFAULT": [
+                "fbsource//third-party/volk:volk",
+            ],
+            "ovr_config//os:android": [
+                "fbsource//third-party/volk:volk",
+                "fbsource//third-party/toolchains:android"
+            ],
+            "ovr_config//os:macos-arm64": [
+                "//third-party/khronos:moltenVK"
+            ],
         })
-        VK_API_PREPROCESSOR_FLAGS += [
-            "-DUSE_VULKAN_WRAPPER",
-            "-DUSE_VULKAN_VOLK",
-        ]
         VK_API_PREPROCESSOR_FLAGS += select({
-            "DEFAULT": [],
-            "ovr_config//os:android": ["-DVK_ANDROID_external_memory_android_hardware_buffer"],
+            "DEFAULT": [
+                "-DUSE_VULKAN_WRAPPER",
+                "-DUSE_VULKAN_VOLK",
+            ],
+            "ovr_config//os:android": [
+                "-DUSE_VULKAN_WRAPPER",
+                "-DUSE_VULKAN_VOLK",
+                "-DVK_ANDROID_external_memory_android_hardware_buffer"
+            ],
+            "ovr_config//os:macos-arm64": []
         })
-    else:
-        VK_API_DEPS += [
-            "fbsource//third-party/swiftshader:swiftshader_vk_headers",
-            "fbsource//third-party/swiftshader/lib/linux-x64:libvk_swiftshader_fbcode",
-            "fbsource//third-party/swiftshader/lib/linux-x64:libvk_swiftshader_so",
-        ]
 
     runtime.cxx_library(
         name = "vulkan_compute_api",

extension/llm/custom_ops/targets.bzl

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,9 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(
+    "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
+    "get_compiler_optimization_flags",
+)
+
 
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.
@@ -34,7 +39,7 @@ def define_common_targets():
            "//executorch/kernels/portable/cpu/util:reduce_util",
            "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
        ],
-       compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"],
+       compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(),
        visibility = [
            "//executorch/...",
            "//executorch/extension/llm/custom_ops/...",

kernels/optimized/cpu/binary_ops.h

Lines changed: 86 additions & 2 deletions
@@ -41,10 +41,62 @@ enum class ElementwiseOptimizedPath {
   kTreatAs1d,
   kBroadcast2dBy1d,
   kBroadcast2dBy1dReverseArguments,
+  kBroadcastNdByNd,
+  kBroadcastNdByNdReverseArguments,
 };
 
 namespace internal {
-inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+
+// Find the single broadcast dimension, if it exists.
+// This path aims to handle broadcasts of the following form:
+// A = [a1, a2, ..., 1, ..., an]
+// B = [b1, b2, ..., bm, ..., bn]
+// OR
+// A = [a1, a2, ..., am, ..., an]
+// B = [b1, b2, ..., 1, ..., bn]
+int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+
+  // The following example is not handled at the moment:
+  // [1, 3, 4, 5]
+  // [2, 3, 4, 5]
+  if (lhs_size != rhs_size) {
+    return 0;
+  }
+
+  int32_t broadcast_dim = 0;
+  // Check
+  // 1. if any dim value is 1 (it constitutes a broadcast dim)
+  // 2. if more than one dim value is 1 (we cannot handle that)
+  // 3. if the non-1 dim values are equal
+  lhs_end--;
+  rhs_end--;
+  while (lhs_end != lhs_begin) {
+    if (*lhs_end == 1 || *rhs_end == 1) {
+      // If more than one broadcast dim is found, return 0.
+      if (broadcast_dim != 0) {
+        return 0;
+      }
+      // A negative index is used.
+      broadcast_dim = lhs_end - lhs.sizes().end();
+    } else if (*lhs_end != *rhs_end) {
+      // If non-1 dim values are not equal, return 0.
+      return 0;
+    }
+    lhs_end--;
+    rhs_end--;
+  }
+  return broadcast_dim;
+}
+
+inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     const Tensor& lhs,
     const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
@@ -63,6 +115,17 @@ inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
     return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
   }
 
+  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
+  // Right now we don't handle last-dim broadcast.
+  if (broadcast_dim < -1) {
+    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
+          return x == 1;
+        }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastNdByNd;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
+    }
+  }
   return ElementwiseOptimizedPath::kNone;
 }
 } // namespace internal
@@ -85,7 +148,28 @@ ElementwiseOptimizedPath inline select_optimized_path(
       internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
     return ElementwiseOptimizedPath::kTreatAs1d;
   }
-  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
+  return internal::select_broadcast_optimized_path(a, b);
+}
+
+std::array<int32_t, 3> inline get_normalized_tensor_size(
+    const Tensor& a,
+    const int32_t broadcast_dim) {
+  ET_CHECK_MSG(
+      a.dim() > broadcast_dim,
+      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
+      a.dim(),
+      broadcast_dim);
+  std::array<int32_t, 3> normalized_tensor_size;
+  normalized_tensor_size[0] = 1;
+  normalized_tensor_size[1] = a.size(broadcast_dim);
+  normalized_tensor_size[2] = 1;
+  for (size_t i = 0; i < broadcast_dim; i++) {
+    normalized_tensor_size[0] *= a.size(i);
+  }
+  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
+    normalized_tensor_size[2] *= a.size(i);
+  }
+  return normalized_tensor_size;
 }
 
 } // namespace executor
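Editor's note: to make the negative-index convention of `get_broadcast_dim` concrete, below is a minimal standalone sketch of the same right-to-left scan on plain shape vectors. The helper name `find_single_broadcast_dim` is hypothetical, and the leading-1 stripping done by the real Tensor-based helper is omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns the broadcast dim as a negative index from the end, or 0 if the
// shapes are not broadcastable along exactly one non-leading dimension.
int32_t find_single_broadcast_dim(
    const std::vector<int32_t>& lhs,
    const std::vector<int32_t>& rhs) {
  if (lhs.size() != rhs.size()) {
    return 0; // Ranks must match (after stripping leading 1s in the real code).
  }
  int32_t broadcast_dim = 0;
  // Walk from the last dimension toward the front; like the helper in the
  // diff, the scan stops before inspecting dimension 0.
  for (int64_t i = static_cast<int64_t>(lhs.size()) - 1; i > 0; --i) {
    if (lhs[i] == 1 || rhs[i] == 1) {
      if (broadcast_dim != 0) {
        return 0; // More than one broadcast dim: not handled.
      }
      broadcast_dim = static_cast<int32_t>(i - static_cast<int64_t>(lhs.size()));
    } else if (lhs[i] != rhs[i]) {
      return 0; // Mismatched non-1 dims: not broadcastable on this path.
    }
  }
  return broadcast_dim;
}

int main() {
  // [2, 1, 4, 5] vs [2, 3, 4, 5]: dim 1 broadcasts, i.e. -3 from the end.
  assert(find_single_broadcast_dim({2, 1, 4, 5}, {2, 3, 4, 5}) == -3);
  // Two broadcast dims: rejected.
  assert(find_single_broadcast_dim({2, 1, 1, 5}, {2, 3, 4, 5}) == 0);
  // A leading-dim broadcast like [1, 3, 4, 5] vs [2, 3, 4, 5] is rejected
  // too, since dim 0 is never inspected (matching the comment in the diff).
  assert(find_single_broadcast_dim({1, 3, 4, 5}, {2, 3, 4, 5}) == 0);
  return 0;
}
```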

kernels/optimized/cpu/op_mul.cpp

Lines changed: 30 additions & 7 deletions
@@ -130,15 +130,19 @@ Tensor& opt_mul_out(
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     const Tensor* lhs;
     const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
       lhs = &b;
       rhs = &a;
     } else {
       // Catch failure to update logic when adding new broadcasting possibility.
       ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+          (selected_optimized_path ==
+           ElementwiseOptimizedPath::kBroadcastNdByNd));
       lhs = &a;
       rhs = &b;
     }
@@ -149,15 +153,34 @@ Tensor& opt_mul_out(
         InvalidArgument,
         out,
         "Failed to resize output tensor.");
+    int64_t outer_size = 1;
+    int64_t broadcast_size;
+    int64_t inner_size;
+    if ((selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+      auto normalized_tensor_size_lhs =
+          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+      outer_size = normalized_tensor_size_lhs[0];
+      broadcast_size = normalized_tensor_size_lhs[1];
+      inner_size = normalized_tensor_size_lhs[2];
+    } else {
+      broadcast_size = lhs->sizes()[lhs->dim() - 2];
+      inner_size = lhs->sizes()[lhs->dim() - 1];
+    }
     ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
       using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
           [](Vec x, Vec y) { return x * y; },
          out.mutable_data_ptr<CTYPE>(),
          lhs->const_data_ptr<CTYPE>(),
          rhs->const_data_ptr<CTYPE>(),
-          lhs->sizes()[lhs->dim() - 2],
-          lhs->sizes()[lhs->dim() - 1]);
+          outer_size,
+          broadcast_size,
+          inner_size);
     });
   } else {
     ScalarType common_type =
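For intuition about the outer/broadcast/inner sizes handed to the vectorized kernel above, here is a hedged sketch of the collapse that `get_normalized_tensor_size` performs, written against a plain shape vector; `normalize_shape` is a hypothetical stand-in for the Tensor-based helper.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Collapse a shape around broadcast_dim into [outer, broadcast, inner]:
// everything before broadcast_dim multiplies into outer, everything after
// multiplies into inner. Mirrors get_normalized_tensor_size in binary_ops.h.
std::array<int64_t, 3> normalize_shape(
    const std::vector<int64_t>& sizes,
    int32_t broadcast_dim) {
  std::array<int64_t, 3> out = {1, sizes[broadcast_dim], 1};
  for (int32_t i = 0; i < broadcast_dim; i++) {
    out[0] *= sizes[i];
  }
  for (size_t i = broadcast_dim + 1; i < sizes.size(); i++) {
    out[2] *= sizes[i];
  }
  return out;
}

int main() {
  // lhs [2, 3, 4, 5] times rhs [2, 1, 4, 5]: get_broadcast_dim returns -3,
  // so broadcast_dim_lhs = 4 + (-3) = 1, and the lhs is viewed as
  // [outer=2, broadcast=3, inner=4*5=20]; the rhs acts as [2, 1, 20].
  auto n = normalize_shape({2, 3, 4, 5}, /*broadcast_dim=*/1);
  assert(n[0] == 2 && n[1] == 3 && n[2] == 20);
  return 0;
}
```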

kernels/optimized/vec/functional_base.h

Lines changed: 44 additions & 24 deletions
@@ -326,10 +326,49 @@ inline void map4(
 }
 
 
-// Map vec_fun across input_data and input_data2, where input_data is
-// a two-dimensional array of size (size, size2), input_data2 is a
-// one-dimensional array of size size2, and input_data2 is broadcast
-// to be of size (size, size2).
+// This function implements a broadcasting binary operation on two tensors,
+// where the lhs tensor is treated as having shape [outer_size, broadcast_size, inner_size]
+// and the rhs tensor is treated as having shape [outer_size, 1, inner_size].
+// Dimension 1 is the broadcasting dimension.
+// This formulation can map broadcasting on any dim = broadcast_dim
+// for any two N-dimensional tensors, where 0 < broadcast_dim < N-1.
+template <typename scalar_t, typename Op>
+inline void broadcasting_map_3d_and_unsqueezed_3d(
+    const Op& vec_fun,
+    scalar_t* output_data,
+    const scalar_t* lhs,
+    const scalar_t* rhs,
+    int64_t outer_size,
+    int64_t broadcast_size,
+    int64_t inner_size) {
+  using Vec = vec::Vectorized<scalar_t>;
+  int64_t outer_stride_lhs = inner_size * broadcast_size;
+  int64_t outer_stride_rhs = inner_size;
+  int64_t broadcast_stride_lhs = inner_size;
+  for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
+    scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
+    const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
+    for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
+      const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
+      scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
+      int64_t inner_idx = 0;
+      for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx);
+      }
+      if (inner_size - inner_idx > 0) {
+        Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
+        Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
+        Vec output_vec = vec_fun(data_vec, data_vec2);
+        output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
+      }
+    }
+  }
+}
+
 template <typename scalar_t, typename Op>
 inline void broadcasting_map_2d_by_1d(
     const Op& vec_fun,
@@ -338,27 +377,8 @@ inline void broadcasting_map_2d_by_1d(
     const scalar_t* input_data2,
     int64_t size,
     int64_t size2) {
-  using Vec = vec::Vectorized<scalar_t>;
-  for (int64_t outer_idx = 0; outer_idx < size; ++outer_idx) {
-    const scalar_t* input_data_row = input_data + outer_idx * size2;
-    scalar_t* output_data_row = output_data + outer_idx * size2;
-    int64_t inner_idx = 0;
-    for (; inner_idx < size2 - (size2 % Vec::size()); inner_idx += Vec::size()) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx);
-    }
-    if (size2 - inner_idx > 0) {
-      Vec data_vec = Vec::loadu(input_data_row + inner_idx, size2 - inner_idx);
-      Vec data_vec2 = Vec::loadu(input_data2 + inner_idx, size2 - inner_idx);
-      Vec output_vec = vec_fun(data_vec, data_vec2);
-      output_vec.store(output_data_row + inner_idx, size2 - inner_idx);
-    }
-  }
+  broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }
 
-
 } // namespace vec
 } // namespace executorch
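As a readability aid, here is a scalar reference of what `broadcasting_map_3d_and_unsqueezed_3d` computes, minus the SIMD main-loop/tail split; `broadcasting_map_3d_reference` is a hypothetical name for this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Scalar reference: out[o][b][i] = f(lhs[o][b][i], rhs[o][0][i]), i.e. the
// single rhs row per outer index is reused across the broadcast dimension.
// The vectorized version above produces the same values, Vec::size() lanes
// at a time.
template <typename T, typename Op>
void broadcasting_map_3d_reference(
    const Op& f,
    T* out,
    const T* lhs,
    const T* rhs,
    int64_t outer_size,
    int64_t broadcast_size,
    int64_t inner_size) {
  for (int64_t o = 0; o < outer_size; ++o) {
    for (int64_t b = 0; b < broadcast_size; ++b) {
      for (int64_t i = 0; i < inner_size; ++i) {
        out[(o * broadcast_size + b) * inner_size + i] = f(
            lhs[(o * broadcast_size + b) * inner_size + i],
            rhs[o * inner_size + i]);
      }
    }
  }
}

int main() {
  // outer=1, broadcast=2, inner=3: lhs is viewed as [1, 2, 3] and rhs as
  // [1, 1, 3], so the single rhs row multiplies both lhs rows.
  std::vector<float> lhs = {1, 2, 3, 4, 5, 6};
  std::vector<float> rhs = {10, 20, 30};
  std::vector<float> out(6);
  broadcasting_map_3d_reference(
      [](float x, float y) { return x * y; },
      out.data(), lhs.data(), rhs.data(), 1, 2, 3);
  assert(out[0] == 10 && out[3] == 40);
  return 0;
}
```

This also explains why `broadcasting_map_2d_by_1d` reduces to a call with `outer_size = 1`: a (size, size2) by (size2,) broadcast is exactly the 3D case with a trivial outer dimension.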
