
Commit 827d50f

Update base for Update on "[ET-VK] Refine partitioner to account for storage type and memory layout"
## Context

There are a variety of ways that tensors can be represented in Vulkan. The two main descriptors for how a tensor is laid out in memory are:

1. Storage type (buffer or texture)
2. Memory layout (which dim is packed along a texel, which dim has a stride of 1, etc.)

Due to the differences between buffers and textures, and between different memory layouts, an implementation of an operator may only support a specific set of (storage type, memory layout) combinations. Furthermore, if an operator implementation supports multiple (storage type, memory layout) combinations, there may be a "preferred" setting that results in optimal performance.

These changes lay the foundation for a memory metadata tagging graph transform, which will make sure that every tensor participating in an operator call has a valid/optimal (storage type, memory layout) setting, and will insert transition operators to transfer input tensors to the correct memory settings when necessary.

An additional required change arises from the fact that Vulkan limits texture and buffer sizes. The partitioner therefore needs to account for the storage types and memory layouts supported by the operator implementation, and check that all tensors participating in a computation can be represented with some (storage type, memory layout) combination supported by the implementation.

## Changes

Improvements to the operator registry:

* Introduce utility functions to check the optimal and enabled storage types and memory layouts for an operator.

Improvements to the partitioner:

* Account for the storage types and memory layouts supported by an operator when deciding whether a node should be partitioned.
* Improve the logic for fusable ops (i.e. the permute/transpose before a mm, which can be fused into linear) to check whether the final target op is supported in Vulkan, and only partition those nodes if so. Otherwise, leave the node unpartitioned so that it can be fused by another backend.

Differential Revision: [D65428843](https://our.internmc.facebook.com/intern/diff/D65428843/)

[ghstack-poisoned]
2 parents 09cf982 + 244546b commit 827d50f
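
As a companion to the description above, here is a minimal sketch of what a (storage type, memory layout) aware operator registry and partitioner check could look like. Everything in it (the enum names, the `OP_FEATURES` registry layout, and the texture-size limit) is an illustrative assumption for exposition, not the actual ExecuTorch Vulkan backend API.

```python
# Illustrative sketch only: VkStorageType, VkMemoryLayout, OP_FEATURES, and the
# texture limit below are assumptions, not the real ExecuTorch Vulkan registry.
from enum import Enum, auto


class VkStorageType(Enum):
    BUFFER = auto()
    TEXTURE_3D = auto()


class VkMemoryLayout(Enum):
    WIDTH_PACKED = auto()
    CHANNELS_PACKED = auto()


# Per-operator entry: which (storage type, memory layout) combinations the
# Vulkan implementation supports, and which combination is preferred.
OP_FEATURES = {
    "aten.mm.default": {
        "storage_types": [VkStorageType.TEXTURE_3D, VkStorageType.BUFFER],
        "memory_layouts": [VkMemoryLayout.WIDTH_PACKED],
        "optimal": (VkStorageType.TEXTURE_3D, VkMemoryLayout.WIDTH_PACKED),
    },
}


def enabled_storage_types(op_name: str) -> list:
    """Utility of the kind described above: storage types enabled for an op."""
    return OP_FEATURES.get(op_name, {}).get("storage_types", [])


def can_partition(op_name, tensor_shapes, max_texture_extent: int = 16384) -> bool:
    """Partition a node only if every tensor in the call can be represented
    with some storage type supported by the op's Vulkan implementation."""
    storage_types = enabled_storage_types(op_name)
    if not storage_types:
        return False  # no Vulkan implementation registered for this op
    for storage in storage_types:
        if storage == VkStorageType.TEXTURE_3D:
            # Textures have per-dimension extent limits (simplified here).
            if any(max(shape, default=1) > max_texture_extent for shape in tensor_shapes):
                continue
        return True
    return False


# Example: a matmul whose operands fit in a texture is partitioned, while an
# op with no registered Vulkan implementation is left for other backends.
assert can_partition("aten.mm.default", [(128, 256), (256, 64)])
assert not can_partition("aten.unknown.default", [(128, 256)])
```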

File tree

10 files changed: +160 -41 lines

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-e47e8794499a4a0130ff4efb8713ff93f4b40c36
+c8a648d4dffb9f0133ff4a2ea0e660b42105d3ad

backends/arm/quantizer/TARGETS

Lines changed: 0 additions & 3 deletions

@@ -3,7 +3,6 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
 python_library(
     name = "arm_quantizer",
     srcs = ["arm_quantizer.py"],
-    typing = True,
     deps = [
         ":arm_quantizer_utils",
         "//caffe2:torch",
@@ -15,7 +14,6 @@ python_library(
 python_library(
     name = "quantization_config",
     srcs = ["quantization_config.py"],
-    typing = True,
     deps = [
         "//caffe2:torch",
     ],
@@ -24,7 +22,6 @@ python_library(
 python_library(
     name = "arm_quantizer_utils",
     srcs = ["arm_quantizer_utils.py"],
-    typing = True,
     deps = [
         ":quantization_config",
     ],

backends/cadence/aot/functions.yaml

Lines changed: 4 additions & 0 deletions

@@ -154,6 +154,10 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::reference::quantized_layer_norm_out
+- func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_layer_norm_per_tensor_out
 
 - func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:

backends/cadence/aot/functions_hifi.yaml

Lines changed: 4 additions & 0 deletions

@@ -125,6 +125,10 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_layer_norm_out
+- func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_layer_norm_per_tensor_out
 
 - func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:

backends/cadence/aot/ops_registrations.py

Lines changed: 21 additions & 0 deletions

@@ -36,6 +36,12 @@
 lib.define(
     "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
 )
+lib.define(
+    "quantized_layer_norm.per_tensor(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
+)
+lib.define(
+    "quantized_layer_norm.per_tensor_out(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
+)
 
 lib.define(
     "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
@@ -180,6 +186,21 @@ def quantized_layer_norm_meta(
     return input.new_empty(input.size(), dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_layer_norm.per_tensor")
+def quantized_layer_norm_per_tensor_meta(
+    input: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    normalized_shape: int,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    eps: float,
+    output_scale: float,
+    output_zero_point: int,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
+
+
 @register_fake("cadence::quantized_relu")
 def quantized_relu_meta(
     X: torch.Tensor,
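
Not part of the diff, but for context: a minimal sketch of how the newly defined `per_tensor` schema could be smoke-tested. It assumes the cadence op registrations above have been imported (so `lib.define` and `register_fake` have already run) and uses `FakeTensorMode`, which only needs the registered fake implementation rather than a real kernel.

```python
# Hypothetical smoke test (not part of this commit): assumes the cadence op
# registrations above have been imported, making the per_tensor schema and its
# fake implementation available.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(2, 8, dtype=torch.int8)
    weight = torch.empty(8)
    bias = torch.empty(8)
    # Arguments follow the schema defined above:
    # (X, X_scale, X_zero_point, normalized_shape, weight, bias, eps,
    #  output_scale, output_zero_point) -> Tensor Y
    y = torch.ops.cadence.quantized_layer_norm.per_tensor(
        x, 0.05, 0, [8], weight, bias, 1e-5, 0.05, 0
    )
    # The registered fake returns an empty tensor with the input's size and
    # dtype, which is what export-time shape propagation relies on.
    assert y.shape == x.shape and y.dtype == x.dtype
```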

backends/cadence/hifi/operators/quantized_layer_norm.cpp

Lines changed: 41 additions & 3 deletions

@@ -27,7 +27,7 @@ namespace native {
 // Compute quantized layer_norm. The current implementation assumes that the
 // input is per-tensor quantized.
 template <typename T>
-void quantized_layer_norm_(
+void quantized_layer_norm_per_tensor_(
     const Tensor& input,
     float input_scale,
     int64_t input_zero_point,
@@ -107,7 +107,7 @@ void quantized_layer_norm_(
   int64_t input_zero_point = in_zero_point.const_data_ptr<int64_t>()[0];
 
   // Call other overload
-  quantized_layer_norm_<T>(
+  quantized_layer_norm_per_tensor_<T>(
       input,
       input_scale,
       input_zero_point,
@@ -120,7 +120,7 @@
 }
 
 void quantized_layer_norm_out(
-    KernelRuntimeContext& ctx,
+    __ET_UNUSED KernelRuntimeContext& ctx,
     const Tensor& input,
     const Tensor& in_scale,
     const Tensor& in_zero_point,
@@ -157,6 +157,44 @@ void quantized_layer_norm_out(
 #undef typed_quantized_layer_norm
 }
 
+void quantized_layer_norm_per_tensor_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& input,
+    double in_scale,
+    int64_t in_zero_point,
+    __ET_UNUSED const IntArrayRef normalized_shape,
+    const Tensor& weight,
+    const Tensor& bias,
+    double eps,
+    double output_scale,
+    int64_t output_zero_point,
+    Tensor& out) {
+#define typed_quantized_layer_norm(ctype, dtype) \
+  case ScalarType::dtype: {                      \
+    quantized_layer_norm_per_tensor_<ctype>(     \
+        input,                                   \
+        in_scale,                                \
+        in_zero_point,                           \
+        weight,                                  \
+        bias,                                    \
+        eps,                                     \
+        output_scale,                            \
+        output_zero_point,                       \
+        out);                                    \
+    break;                                       \
+  }
+
+  ScalarType dtype = input.scalar_type();
+  switch (dtype) {
+    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm)
+    default:
+      ET_DCHECK_MSG(
+          false, "Unhandled dtype %s", torch::executor::toString(dtype));
+  }
+
+#undef typed_quantized_layer_norm
+}
+
 }; // namespace native
 }; // namespace HiFi
 }; // namespace impl

backends/cadence/reference/operators/quantized_layer_norm.cpp

Lines changed: 51 additions & 7 deletions

@@ -11,9 +11,11 @@
 
 #include <cmath>
 
-using executorch::aten::Tensor;
-using executorch::runtime::getLeadingDims;
-using executorch::runtime::KernelRuntimeContext;
+using ::executorch::aten::IntArrayRef;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::getLeadingDims;
+using ::executorch::runtime::KernelRuntimeContext;
 
 namespace impl {
 namespace reference {
@@ -22,7 +24,7 @@ namespace native {
 // Compute quantized layer_norm. The current implementation assumes that the
 // input is per-tensor quantized.
 template <typename T>
-void quantized_layer_norm_(
+void quantized_layer_norm_per_tensor_(
     const Tensor& input,
     double input_scale,
     int64_t input_zero_point,
@@ -98,7 +100,7 @@ void quantized_layer_norm_(
   int64_t input_zero_point = in_zero_point.const_data_ptr<int64_t>()[0];
 
   // Call other overload
-  quantized_layer_norm_<T>(
+  quantized_layer_norm_per_tensor_<T>(
       input,
       input_scale,
       input_zero_point,
@@ -111,11 +113,11 @@
 }
 
 void quantized_layer_norm_out(
-    KernelRuntimeContext& ctx,
+    __ET_UNUSED KernelRuntimeContext& ctx,
     const Tensor& input,
     const Tensor& in_scale,
     const Tensor& in_zero_point,
-    const executorch::aten::IntArrayRef normalized_shape,
+    __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape,
     const Tensor& weight,
     const Tensor& bias,
     double eps,
@@ -152,6 +154,48 @@ void quantized_layer_norm_out(
   }
 }
 
+void quantized_layer_norm_per_tensor_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& input,
+    double in_scale,
+    int64_t in_zero_point,
+    __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape,
+    const Tensor& weight,
+    const Tensor& bias,
+    double eps,
+    double output_scale,
+    int64_t output_zero_point,
+    Tensor& out) {
+  if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
+    quantized_layer_norm_per_tensor_<uint8_t>(
+        input,
+        in_scale,
+        in_zero_point,
+        weight,
+        bias,
+        eps,
+        output_scale,
+        output_zero_point,
+        out);
+  } else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
+    quantized_layer_norm_per_tensor_<int8_t>(
+        input,
+        in_scale,
+        in_zero_point,
+        weight,
+        bias,
+        eps,
+        output_scale,
+        output_zero_point,
+        out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
+  }
+}
+
 }; // namespace native
 }; // namespace reference
 }; // namespace impl

examples/models/llama3_2_vision/preprocess/export_preprocess.py

Lines changed: 9 additions & 16 deletions

@@ -24,29 +24,22 @@ def main():
         strict=False,
     )
 
-    # Executorch
+    # AOTInductor. Note: export AOTI before ExecuTorch, as
+    # ExecuTorch will modify the ExportedProgram.
+    torch._inductor.aot_compile(
+        ep.module(),
+        model.get_example_inputs(),
+        options={"aot_inductor.output_path": "preprocess_aoti.so"},
+    )
+
+    # Executorch.
     edge_program = to_edge(
         ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
     )
     et_program = edge_program.to_executorch()
     with open("preprocess_et.pte", "wb") as file:
         et_program.write_to_file(file)
 
-    # Export.
-    # ep = torch.export.export(
-    #     model.get_eager_model(),
-    #     model.get_example_inputs(),
-    #     dynamic_shapes=model.get_dynamic_shapes(),
-    #     strict=False,
-    # )
-    #
-    # # AOTInductor
-    # torch._inductor.aot_compile(
-    #     ep.module(),
-    #     model.get_example_inputs(),
-    #     options={"aot_inductor.output_path": "preprocess_aoti.so"},
-    # )
-
 
 if __name__ == "__main__":
     main()

examples/models/llama3_2_vision/preprocess/test_preprocess.py

Lines changed: 28 additions & 10 deletions

@@ -26,6 +26,7 @@
 )
 
 from PIL import Image
+from torch._inductor.package import package_aoti
 
 from torchtune.models.clip.inference._transform import CLIPImageTransform
 
@@ -55,31 +56,46 @@ def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
         possible_resolutions=None,
     )
 
+    # Eager model.
     model = CLIPImageTransformModel(config)
 
+    # Exported model.
     exported_model = torch.export.export(
         model.get_eager_model(),
         model.get_example_inputs(),
         dynamic_shapes=model.get_dynamic_shapes(),
         strict=False,
     )
 
-    # aoti_path = torch._inductor.aot_compile(
-    #     exported_model.module(),
-    #     model.get_example_inputs(),
-    # )
+    # AOTInductor model.
+    so = torch._export.aot_compile(
+        exported_model.module(),
+        args=model.get_example_inputs(),
+        options={"aot_inductor.package": True},
+        dynamic_shapes=model.get_dynamic_shapes(),
+    )
+    aoti_path = "preprocess.pt2"
+    package_aoti(aoti_path, so)
 
     edge_program = to_edge(
         exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
    )
     executorch_model = edge_program.to_executorch()
 
+    # Re-export as ExecuTorch edits the ExportedProgram.
+    exported_model = torch.export.export(
+        model.get_eager_model(),
+        model.get_example_inputs(),
+        dynamic_shapes=model.get_dynamic_shapes(),
+        strict=False,
+    )
+
     return {
         "config": config,
         "reference_model": reference_model,
         "model": model,
         "exported_model": exported_model,
-        # "aoti_path": aoti_path,
+        "aoti_path": aoti_path,
         "executorch_model": executorch_model,
     }
 
@@ -265,11 +281,13 @@ def run_preprocess(
         ), f"Executorch model: expected {reference_ar} but got {et_ar.tolist()}"
 
         # Run aoti model and check it matches reference model.
-        # aoti_path = models["aoti_path"]
-        # aoti_model = torch._export.aot_load(aoti_path, "cpu")
-        # aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
-        # self.assertTrue(torch.allclose(reference_image, aoti_image))
-        # self.assertEqual(reference_ar, aoti_ar.tolist())
+        aoti_path = models["aoti_path"]
+        aoti_model = torch._inductor.aoti_load_package(aoti_path)
+        aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
+        assert_expected(aoti_image, reference_image, rtol=0, atol=1e-4)
+        assert (
+            reference_ar == aoti_ar.tolist()
+        ), f"AOTI model: expected {reference_ar} but got {aoti_ar.tolist()}"
 
         # This test setup mirrors the one in torchtune:
         # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py

install_requirements.py

Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ def python_is_compatible():
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-NIGHTLY_VERSION = "dev20241030"
+NIGHTLY_VERSION = "dev20241101"
 
 # The pip repository that hosts nightly torch packages.
 TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu"
