Skip to content

Commit 65e76ee

Browse files
committed
Update on "[ET-VK] Introduce memory metadata tagging pass"
## Context As title; implements the memory metadata tagging graph transform described in the dependent diff. See the comments for more details. Differential Revision: [D65428842](https://our.internmc.facebook.com/intern/diff/D65428842/) [ghstack-poisoned]
2 parents f0b6c06 + d347a97 commit 65e76ee

File tree

15 files changed

+288
-67
lines changed

15 files changed

+288
-67
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e47e8794499a4a0130ff4efb8713ff93f4b40c36
1+
c8a648d4dffb9f0133ff4a2ea0e660b42105d3ad

backends/arm/quantizer/TARGETS

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
33
python_library(
44
name = "arm_quantizer",
55
srcs = ["arm_quantizer.py"],
6-
typing = True,
76
deps = [
87
":arm_quantizer_utils",
98
"//caffe2:torch",
@@ -15,7 +14,6 @@ python_library(
1514
python_library(
1615
name = "quantization_config",
1716
srcs = ["quantization_config.py"],
18-
typing = True,
1917
deps = [
2018
"//caffe2:torch",
2119
],
@@ -24,7 +22,6 @@ python_library(
2422
python_library(
2523
name = "arm_quantizer_utils",
2624
srcs = ["arm_quantizer_utils.py"],
27-
typing = True,
2825
deps = [
2926
":quantization_config",
3027
],

backends/cadence/aot/functions.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@
154154
kernels:
155155
- arg_meta: null
156156
kernel_name: impl::reference::quantized_layer_norm_out
157+
- func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
158+
kernels:
159+
- arg_meta: null
160+
kernel_name: impl::reference::quantized_layer_norm_per_tensor_out
157161

158162
- func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
159163
kernels:

backends/cadence/aot/functions_hifi.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@
125125
kernels:
126126
- arg_meta: null
127127
kernel_name: cadence::impl::HiFi::quantized_layer_norm_out
128+
- func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
129+
kernels:
130+
- arg_meta: null
131+
kernel_name: cadence::impl::HiFi::quantized_layer_norm_per_tensor_out
128132

129133
- func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
130134
kernels:

backends/cadence/aot/ops_registrations.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636
lib.define(
3737
"quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
3838
)
39+
lib.define(
40+
"quantized_layer_norm.per_tensor(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
41+
)
42+
lib.define(
43+
"quantized_layer_norm.per_tensor_out(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
44+
)
3945

4046
lib.define(
4147
"quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
@@ -180,6 +186,21 @@ def quantized_layer_norm_meta(
180186
return input.new_empty(input.size(), dtype=input.dtype)
181187

182188

189+
@register_fake("cadence::quantized_layer_norm.per_tensor")
190+
def quantized_layer_norm_per_tensor_meta(
191+
input: torch.Tensor,
192+
X_scale: float,
193+
X_zero_point: int,
194+
normalized_shape: int,
195+
weight: torch.Tensor,
196+
bias: torch.Tensor,
197+
eps: float,
198+
output_scale: float,
199+
output_zero_point: int,
200+
) -> torch.Tensor:
201+
return input.new_empty(input.size(), dtype=input.dtype)
202+
203+
183204
@register_fake("cadence::quantized_relu")
184205
def quantized_relu_meta(
185206
X: torch.Tensor,

backends/cadence/hifi/operators/quantized_layer_norm.cpp

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace native {
2727
// Compute quantized layer_norm. The current implementation assumes that the
2828
// input is per-tensor quantized.
2929
template <typename T>
30-
void quantized_layer_norm_(
30+
void quantized_layer_norm_per_tensor_(
3131
const Tensor& input,
3232
float input_scale,
3333
int64_t input_zero_point,
@@ -107,7 +107,7 @@ void quantized_layer_norm_(
107107
int64_t input_zero_point = in_zero_point.const_data_ptr<int64_t>()[0];
108108

109109
// Call other overload
110-
quantized_layer_norm_<T>(
110+
quantized_layer_norm_per_tensor_<T>(
111111
input,
112112
input_scale,
113113
input_zero_point,
@@ -120,7 +120,7 @@ void quantized_layer_norm_(
120120
}
121121

122122
void quantized_layer_norm_out(
123-
KernelRuntimeContext& ctx,
123+
__ET_UNUSED KernelRuntimeContext& ctx,
124124
const Tensor& input,
125125
const Tensor& in_scale,
126126
const Tensor& in_zero_point,
@@ -157,6 +157,44 @@ void quantized_layer_norm_out(
157157
#undef typed_quantized_layer_norm
158158
}
159159

160+
void quantized_layer_norm_per_tensor_out(
161+
__ET_UNUSED KernelRuntimeContext& ctx,
162+
const Tensor& input,
163+
double in_scale,
164+
int64_t in_zero_point,
165+
__ET_UNUSED const IntArrayRef normalized_shape,
166+
const Tensor& weight,
167+
const Tensor& bias,
168+
double eps,
169+
double output_scale,
170+
int64_t output_zero_point,
171+
Tensor& out) {
172+
#define typed_quantized_layer_norm(ctype, dtype) \
173+
case ScalarType::dtype: { \
174+
quantized_layer_norm_per_tensor_<ctype>( \
175+
input, \
176+
in_scale, \
177+
in_zero_point, \
178+
weight, \
179+
bias, \
180+
eps, \
181+
output_scale, \
182+
output_zero_point, \
183+
out); \
184+
break; \
185+
}
186+
187+
ScalarType dtype = input.scalar_type();
188+
switch (dtype) {
189+
ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm)
190+
default:
191+
ET_DCHECK_MSG(
192+
false, "Unhandled dtype %s", torch::executor::toString(dtype));
193+
}
194+
195+
#undef typed_quantized_layer_norm
196+
}
197+
160198
}; // namespace native
161199
}; // namespace HiFi
162200
}; // namespace impl

backends/cadence/reference/operators/quantized_layer_norm.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111

1212
#include <cmath>
1313

14-
using executorch::aten::Tensor;
15-
using executorch::runtime::getLeadingDims;
16-
using executorch::runtime::KernelRuntimeContext;
14+
using ::executorch::aten::IntArrayRef;
15+
using ::executorch::aten::ScalarType;
16+
using ::executorch::aten::Tensor;
17+
using ::executorch::runtime::getLeadingDims;
18+
using ::executorch::runtime::KernelRuntimeContext;
1719

1820
namespace impl {
1921
namespace reference {
@@ -22,7 +24,7 @@ namespace native {
2224
// Compute quantized layer_norm. The current implementation assumes that the
2325
// input is per-tensor quantized.
2426
template <typename T>
25-
void quantized_layer_norm_(
27+
void quantized_layer_norm_per_tensor_(
2628
const Tensor& input,
2729
double input_scale,
2830
int64_t input_zero_point,
@@ -98,7 +100,7 @@ void quantized_layer_norm_(
98100
int64_t input_zero_point = in_zero_point.const_data_ptr<int64_t>()[0];
99101

100102
// Call other overload
101-
quantized_layer_norm_<T>(
103+
quantized_layer_norm_per_tensor_<T>(
102104
input,
103105
input_scale,
104106
input_zero_point,
@@ -111,11 +113,11 @@ void quantized_layer_norm_(
111113
}
112114

113115
void quantized_layer_norm_out(
114-
KernelRuntimeContext& ctx,
116+
__ET_UNUSED KernelRuntimeContext& ctx,
115117
const Tensor& input,
116118
const Tensor& in_scale,
117119
const Tensor& in_zero_point,
118-
const executorch::aten::IntArrayRef normalized_shape,
120+
__ET_UNUSED const executorch::aten::IntArrayRef normalized_shape,
119121
const Tensor& weight,
120122
const Tensor& bias,
121123
double eps,
@@ -152,6 +154,48 @@ void quantized_layer_norm_out(
152154
}
153155
}
154156

157+
void quantized_layer_norm_per_tensor_out(
158+
__ET_UNUSED KernelRuntimeContext& ctx,
159+
const Tensor& input,
160+
double in_scale,
161+
int64_t in_zero_point,
162+
__ET_UNUSED const executorch::aten::IntArrayRef normalized_shape,
163+
const Tensor& weight,
164+
const Tensor& bias,
165+
double eps,
166+
double output_scale,
167+
int64_t output_zero_point,
168+
Tensor& out) {
169+
if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
170+
quantized_layer_norm_per_tensor_<uint8_t>(
171+
input,
172+
in_scale,
173+
in_zero_point,
174+
weight,
175+
bias,
176+
eps,
177+
output_scale,
178+
output_zero_point,
179+
out);
180+
} else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
181+
quantized_layer_norm_per_tensor_<int8_t>(
182+
input,
183+
in_scale,
184+
in_zero_point,
185+
weight,
186+
bias,
187+
eps,
188+
output_scale,
189+
output_zero_point,
190+
out);
191+
} else {
192+
ET_CHECK_MSG(
193+
false,
194+
"Unhandled input dtype %hhd",
195+
static_cast<int8_t>(input.scalar_type()));
196+
}
197+
}
198+
155199
}; // namespace native
156200
}; // namespace reference
157201
}; // namespace impl

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,16 @@ class TagMemoryMetaPass(ExportPass):
6161
necessary.
6262
"""
6363

64-
def __init__(self):
64+
def __init__(
65+
self,
66+
texture_limits: utils.ImageExtents,
67+
default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D,
68+
default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED,
69+
):
6570
super().__init__()
66-
self.default_storage: VkStorageType = VkStorageType.DEFAULT_STORAGE
67-
self.default_layout: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT
68-
self.texture_limits = (16384, 16384, 2048)
71+
self.default_storage: VkStorageType = default_storage_type
72+
self.default_layout: VkMemoryLayout = default_memory_layout
73+
self.texture_limits = texture_limits
6974

7075
def propose_node_storage(
7176
self,

backends/vulkan/op_registry.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def __init__(
5555
self.valid_packed_dims = valid_packed_dims
5656

5757
def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
58+
"""
59+
Derive the set of memory layouts supported by the texture implementation based
60+
on the valid packed dimensions.
61+
"""
5862
layouts = set()
5963

6064
if PackedDim.WIDTH in self.valid_packed_dims:
@@ -112,6 +116,15 @@ def __init__(
112116
self.check_node_fn = check_node_fn
113117

114118
def propose_storage_type(self) -> Optional[VkStorageType]:
119+
"""
120+
Propose a storage type that should be used for this operator. A proposal can be
121+
made if one of the following is true:
122+
1. The operator specifies an optimal storage type
123+
2. Only one storage type is supported.
124+
125+
If both storage types are supported and no optimal storage type is specified,
126+
then None is returned to indicate that there is no preference in storage type.
127+
"""
115128
if self.optimal_storage is not None:
116129
return self.optimal_storage
117130

@@ -123,6 +136,9 @@ def propose_storage_type(self) -> Optional[VkStorageType]:
123136
return None
124137

125138
def supported_storage_types(self) -> Set[VkStorageType]:
139+
"""
140+
Return the set of storage types supported by this operator.
141+
"""
126142
storage_types = set()
127143
if self.texture_impl is not None:
128144
storage_types.add(VkStorageType.TEXTURE_3D)
@@ -132,6 +148,16 @@ def supported_storage_types(self) -> Set[VkStorageType]:
132148
return storage_types
133149

134150
def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
151+
"""
152+
Given a storage type as a precondition, propose a memory layout that should be
153+
used for this operator. A proposal can be made if one of the following is true:
154+
1. The operator specifies an optimal memory layout
155+
2. Only one memory layout is supported.
156+
157+
If multiple memory layouts are supported and no optimal memory layout is
158+
specified then return None to indicate that the "best" memory layout for the
159+
operator is ambiguous.
160+
"""
135161
if self.optimal_layout is not None:
136162
return self.optimal_layout
137163

@@ -144,6 +170,10 @@ def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayo
144170
return None
145171

146172
def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
173+
"""
174+
Return the set of memory layouts supported by this operator for a given storage
175+
type.
176+
"""
147177
if storage == VkStorageType.TEXTURE_3D:
148178
assert self.texture_impl is not None
149179
return self.texture_impl.valid_memory_layouts()
@@ -517,13 +547,5 @@ def get_op_features(target: OpKey) -> OpFeatures:
517547
return vulkan_supported_ops[target]
518548

519549

520-
def optimal_storage_type(target: OpKey) -> Optional[VkStorageType]:
521-
return get_op_features(target).optimal_storage
522-
523-
524-
def optimal_memory_layout(target: OpKey) -> Optional[VkMemoryLayout]:
525-
return get_op_features(target).optimal_layout
526-
527-
528550
def handles_own_prepacking(target: OpKey) -> bool:
529551
return get_op_features(target).handles_own_prepacking

0 commit comments

Comments
 (0)