Skip to content

Commit 191e6c4

Browse files
committed
Update on "[ET-VK] Enable int8 tiled compute shader to be used with buffer tensors"
## Context As title. Allow the optimized int8 tiled compute shader to be usable for buffer-backed tensors as well. ## Changes * Generate buffer variants for the int8 linear tiled shader * Force the scales tensor to always be a buffer to reduce the number of shader variants that need to be generated. * Generate an additional variant that computes only 1 output row * Do not require output rows to be an exact multiple of 4 or 6 to use the tiled implementation Differential Revision: [D73276277](https://our.internmc.facebook.com/intern/diff/D73276277/) [ghstack-poisoned]
2 parents 55cb072 + 2639ea1 commit 191e6c4

File tree

22 files changed

+431
-19
lines changed

22 files changed

+431
-19
lines changed

examples/models/llama/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from executorch.examples.models.llama.llama_transformer import Transformer
1919

2020
from executorch.examples.models.llama.model_args import ModelArgs
21+
from torchao.utils import TorchAOBaseTensor
2122

2223
try:
2324
from .fairseq2 import convert_to_llama_checkpoint
@@ -257,6 +258,9 @@ def __init__(self, **kwargs):
257258
strict=False,
258259
assign=True,
259260
) # self.model_ = Transformer(gptconf)
261+
for param in self.model_.parameters():
262+
if isinstance(param, TorchAOBaseTensor):
263+
param.requires_grad = False
260264
else:
261265
print("Checkpoint not provided, defaulting weights to zeros.")
262266
self.model_.to_empty(device="cpu")

extension/llm/export/builder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
4242
from torch.export import export_for_training, ExportedProgram
4343
from torch.nn.attention import SDPBackend
44+
from torchao.utils import unwrap_tensor_subclass
4445

4546
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
4647
logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -199,6 +200,11 @@ def _get_edge_config(self) -> EdgeCompileConfig:
199200
return edge_config
200201

201202
def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
203+
if module is not None:
204+
unwrap_tensor_subclass(module)
205+
else:
206+
unwrap_tensor_subclass(self.model)
207+
202208
dynamic_shape = self._get_dynamic_shape()
203209
# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
204210
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,8 @@
423423

424424
- op: var.out
425425

426+
- op: view_as_real_copy.out
427+
426428
- op: view_copy.out
427429

428430
- op: where.self_out
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
#include <executorch/runtime/platform/assert.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
namespace native {
16+
17+
using Tensor = executorch::aten::Tensor;
18+
19+
namespace {
20+
21+
template <typename SELF_CTYPE, typename OUT_CTYPE>
22+
inline void _to_impl(const Tensor& self, Tensor& out) {
23+
auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
24+
auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
25+
26+
for (size_t i = 0, e = self.numel(); i < e; i++) {
27+
auto val_in = self_data[i];
28+
out_data[2 * i] = static_cast<OUT_CTYPE>(val_in.real_);
29+
out_data[2 * i + 1] = static_cast<OUT_CTYPE>(val_in.imag_);
30+
}
31+
}
32+
33+
} // namespace
34+
35+
// view_as_real_copy(Tensor self) -> Tensor
36+
Tensor& view_as_real_copy_out(
37+
KernelRuntimeContext& ctx,
38+
const Tensor& self,
39+
Tensor& out) {
40+
(void)ctx;
41+
42+
// Get the output shape
43+
Tensor::SizesType expected_output_size[kTensorDimensionLimit];
44+
get_view_as_real_copy_out_target_size(self, expected_output_size);
45+
46+
// Resize for dynamic shape
47+
ET_KERNEL_CHECK_MSG(
48+
ctx,
49+
resize_tensor(
50+
out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
51+
Error::Ok,
52+
InvalidArgument,
53+
out,
54+
"Failed to resize output tensor.");
55+
56+
// The input tensor must be complex type
57+
ET_KERNEL_CHECK_MSG(
58+
ctx,
59+
executorch::runtime::isComplexType(self.scalar_type()),
60+
InvalidArgument,
61+
out,
62+
"Input tensor must be complex type");
63+
64+
ET_KERNEL_CHECK(
65+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
66+
67+
constexpr auto op_name = "view_as_real_copy.out";
68+
69+
ET_SWITCH_COMPLEXH_TYPES(self.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
70+
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
71+
_to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
72+
});
73+
});
74+
75+
return out;
76+
}
77+
78+
} // namespace native
79+
} // namespace executor
80+
} // namespace torch

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,5 +1018,14 @@ void get_unfold_copy_out_target_size(
10181018
*out_ndim = self.dim() + 1;
10191019
}
10201020

1021+
void get_view_as_real_copy_out_target_size(
1022+
const Tensor& self,
1023+
executorch::aten::SizesType* out_sizes) {
1024+
for (auto i : c10::irange(self.dim())) {
1025+
out_sizes[i] = self.size(i);
1026+
}
1027+
out_sizes[self.dim()] = 2;
1028+
}
1029+
10211030
} // namespace executor
10221031
} // namespace torch

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,5 +247,9 @@ void get_unfold_copy_out_target_size(
247247
executorch::aten::SizesType* out_sizes,
248248
size_t* out_ndim);
249249

250+
void get_view_as_real_copy_out_target_size(
251+
const Tensor& self,
252+
executorch::aten::SizesType* out_sizes);
253+
250254
} // namespace executor
251255
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,11 @@
957957
- arg_meta: null
958958
kernel_name: torch::executor::var_out
959959

960+
- op: view_as_real_copy.out
961+
kernels:
962+
- arg_meta: null
963+
kernel_name: torch::executor::view_as_real_copy_out
964+
960965
- op: view_copy.out
961966
kernels:
962967
- arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ set(all_test_sources
242242
"op_upsample_bilinear2d_test.cpp"
243243
"op_upsample_nearest2d_test.cpp"
244244
"op_var_test.cpp"
245+
"op_view_as_real_copy_test.cpp"
245246
"op_view_copy_test.cpp"
246247
"op_where_test.cpp"
247248
"op_zeros_test.cpp"
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14+
15+
#include <gtest/gtest.h>
16+
17+
using namespace ::testing;
18+
using executorch::aten::ScalarType;
19+
using executorch::aten::Tensor;
20+
using torch::executor::testing::TensorFactory;
21+
22+
class OpViewAsRealTest : public OperatorTest {
23+
protected:
24+
Tensor& view_as_real_copy_out(const Tensor& self, Tensor& out) {
25+
return torch::executor::aten::view_as_real_copy_outf(context_, self, out);
26+
}
27+
28+
template <typename CTYPE, ScalarType DTYPE>
29+
void run_complex_smoke_test() {
30+
TensorFactory<DTYPE> tf;
31+
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
32+
TensorFactory<REAL_DTYPE> tf_out;
33+
34+
Tensor in = tf.make(
35+
{2, 2},
36+
{CTYPE(3, 4), CTYPE(-1.7, 7.4), CTYPE(5, -12), CTYPE(8.3, 0.1)});
37+
Tensor out = tf_out.zeros({2, 2, 2});
38+
Tensor expected =
39+
tf_out.make({2, 2, 2}, {3, 4, -1.7, 7.4, 5, -12, 8.3, 0.1});
40+
Tensor ret = view_as_real_copy_out(in, out);
41+
42+
EXPECT_TENSOR_EQ(out, ret);
43+
EXPECT_TENSOR_EQ(out, expected);
44+
}
45+
46+
// Tests on tensors with 0 size
47+
template <typename CTYPE, ScalarType DTYPE>
48+
void test_empty_input() {
49+
TensorFactory<DTYPE> tf;
50+
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
51+
TensorFactory<REAL_DTYPE> tf_out;
52+
53+
Tensor in = tf.make(/*sizes=*/{3, 0, 4}, /*data=*/{});
54+
Tensor out = tf_out.zeros({3, 0, 4, 2});
55+
Tensor expected = tf_out.make(/*sizes=*/{3, 0, 4, 2}, /*data=*/{});
56+
Tensor ret = view_as_real_copy_out(in, out);
57+
58+
EXPECT_TENSOR_EQ(out, ret);
59+
EXPECT_TENSOR_EQ(out, expected);
60+
}
61+
62+
// Tests on 0-dim input tensors
63+
template <typename CTYPE, ScalarType DTYPE>
64+
void zero_dim_input() {
65+
TensorFactory<DTYPE> tf;
66+
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
67+
TensorFactory<REAL_DTYPE> tf_out;
68+
69+
Tensor in = tf.make(/*sizes=*/{}, {CTYPE(0, 0)});
70+
Tensor out = tf_out.zeros({2});
71+
Tensor expected = tf_out.zeros(/*sizes=*/{2});
72+
Tensor ret = view_as_real_copy_out(in, out);
73+
74+
EXPECT_TENSOR_EQ(out, ret);
75+
EXPECT_TENSOR_EQ(out, expected);
76+
}
77+
};
78+
79+
TEST_F(OpViewAsRealTest, ComplexSmokeTest) {
80+
#define RUN_SMOKE_TEST(ctype, dtype) \
81+
run_complex_smoke_test<ctype, ScalarType::dtype>(); \
82+
test_empty_input<ctype, ScalarType::dtype>(); \
83+
zero_dim_input<ctype, ScalarType::dtype>();
84+
ET_FORALL_COMPLEXH_TYPES(RUN_SMOKE_TEST);
85+
#undef RUN_SMOKE_TEST
86+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def define_common_targets():
331331
_common_op_test("op_upsample_bilinear2d_test", ["aten", "portable"])
332332
_common_op_test("op_upsample_nearest2d_test", ["aten", "portable"])
333333
_common_op_test("op_var_test", ["aten", "portable"])
334+
_common_op_test("op_view_as_real_copy_test", ["aten", "portable"])
334335
_common_op_test("op_view_copy_test", ["aten", "portable"])
335336
_common_op_test("op_where_test", ["aten", "portable"])
336337
_common_op_test("op_zeros_test", ["aten", "portable"])

0 commit comments

Comments
 (0)