Skip to content

Commit 42da8ed

Browse files
committed
Update base for Update on "[wip] Export lora weights to sep file"
Differential Revision: [D83777195](https://our.internmc.facebook.com/intern/diff/D83777195/) [ghstack-poisoned]
2 parents 3bcbc5d + a39866c commit 42da8ed

File tree

17 files changed

+561
-163
lines changed

17 files changed

+561
-163
lines changed

.github/workflows/trunk.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -289,6 +289,7 @@ jobs:
289289
- test_arm_baremetal: test_models_ethos-u55
290290
- test_arm_baremetal: test_models_ethos-u85
291291
- test_arm_baremetal: test_smaller_stories_llama
292+
- test_arm_baremetal: test_memory_allocation
292293
fail-fast: false
293294
with:
294295
runner: linux.2xlarge.memory

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 64 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,9 @@
2626
NNCHW_ORDER,
2727
NNHWC_INVERSE_ORDER,
2828
NNHWC_ORDER,
29+
NNNCHW_ORDER,
30+
NNNHWC_INVERSE_ORDER,
31+
NNNHWC_ORDER,
2932
)
3033
from executorch.exir import ExportedProgram
3134
from executorch.exir.dialects._ops import ops as exir_ops
@@ -51,12 +54,6 @@ class ToTosaMemoryFormatPass(ExportPass):
5154

5255
_passes_required_after: Set[Type[ExportPass]] = set()
5356

54-
NHWC_order = (0, 2, 3, 1)
55-
NHWC_inverse_order = (0, 3, 1, 2)
56-
HWCM_order = (2, 3, 0, 1)
57-
NNHWC_order = (0, 1, 3, 4, 2)
58-
NNHWC_inverse_order = (0, 1, 4, 2, 3)
59-
6057
def __init__(self, exported_program: ExportedProgram) -> None:
6158
self.exported_program = exported_program
6259
super().__init__()
@@ -93,7 +90,11 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
9390
@staticmethod
9491
def memory_format_differs(shape):
9592
"""Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
96-
if len(shape) >= 5:
93+
if len(shape) >= 6:
94+
C = shape[3]
95+
H = shape[4]
96+
W = shape[5]
97+
elif len(shape) == 5:
9798
C = shape[2]
9899
H = shape[3]
99100
W = shape[4]
@@ -112,25 +113,26 @@ def memory_format_differs(shape):
112113

113114
@staticmethod
114115
def is_channel_reshape(input_shape, output_shape):
115-
"""Returns true if the reshape changes the channel dimension"""
116-
if not (
117-
(len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
118-
or (len(input_shape) == 4 and len(output_shape) == 5)
119-
or (len(input_shape) == 5 and len(output_shape) == 4)
120-
):
116+
"""Returns true if reshape changes the channel dimension or batch product dimension(s)"""
117+
118+
valid_ranks = {4, 5, 6}
119+
120+
if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
121121
return False
122122

123123
C_old = input_shape[-3]
124124
C_new = output_shape[-3]
125125

126-
N_new = (
127-
output_shape[0]
128-
if len(output_shape) == 4
129-
else output_shape[0] * output_shape[1]
130-
)
131-
N_old = (
132-
input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
133-
)
126+
def get_batch_prod_dim(shape):
127+
product = 1
128+
129+
for dim in shape[:-3]:
130+
product = product * dim
131+
132+
return product
133+
134+
N_old = get_batch_prod_dim(input_shape)
135+
N_new = get_batch_prod_dim(output_shape)
134136

135137
return (N_old != N_new) or (C_old != C_new)
136138

@@ -141,17 +143,27 @@ def insert_input_transpose(node, input_node, graph_module):
141143
node.replace_input_with(input_node, pre_permute_node)
142144
return
143145

146+
if len(get_first_fake_tensor(input_node).size()) == 6:
147+
mem_format = NNNHWC_INVERSE_ORDER
148+
elif len(get_first_fake_tensor(input_node).size()) == 5:
149+
mem_format = NNHWC_INVERSE_ORDER
150+
else:
151+
mem_format = NHWC_INVERSE_ORDER
152+
# Guard: mem_format must be a true permutation for the current rank
153+
_rank_ = len(
154+
get_first_fake_tensor(input_node).size()
155+
) # or (node) in output path
156+
assert sorted(mem_format) == list(
157+
range(_rank_)
158+
), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose"
159+
144160
with graph_module.graph.inserting_before(node):
145161
permute_node = create_node(
146162
graph_module.graph,
147163
exir_ops.backend.tosa.TRANSPOSE.default,
148164
args=(
149165
input_node,
150-
list(
151-
NNHWC_INVERSE_ORDER
152-
if len(get_first_fake_tensor(input_node).size()) == 5
153-
else NHWC_INVERSE_ORDER
154-
),
166+
list(mem_format),
155167
),
156168
from_node=node,
157169
)
@@ -163,26 +175,38 @@ def insert_input_transpose(node, input_node, graph_module):
163175

164176
@staticmethod
165177
def insert_output_transpose(node, graph_module):
178+
179+
if len(get_first_fake_tensor(node).size()) == 6:
180+
mem_format = NNNHWC_ORDER
181+
elif len(get_first_fake_tensor(node).size()) == 5:
182+
mem_format = NNHWC_ORDER
183+
else:
184+
mem_format = NHWC_ORDER
185+
# Guard: mem_format must be a true permutation for the current rank
186+
_rank_ = len(get_first_fake_tensor(node).size()) # or (node) in output path
187+
assert sorted(mem_format) == list(
188+
range(_rank_)
189+
), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose"
190+
166191
with graph_module.graph.inserting_after(node):
167192
permute_node = create_node(
168193
graph_module.graph,
169194
exir_ops.backend.tosa.TRANSPOSE.default,
170195
args=(
171196
node,
172-
list(
173-
NNHWC_ORDER
174-
if len(get_first_fake_tensor(node).size()) == 5
175-
else NHWC_ORDER
176-
),
197+
list(mem_format),
177198
),
178199
from_node=node,
179200
)
180201

181-
permute_node.meta["tosa_dim_order"] = (
182-
NNHWC_ORDER
183-
if len(get_first_fake_tensor(node).size()) == 5
184-
else NHWC_ORDER
185-
)
202+
rank = len(get_first_fake_tensor(node).size())
203+
if rank == 6:
204+
permute_node.meta["tosa_dim_order"] = NNNHWC_ORDER
205+
elif rank == 5:
206+
permute_node.meta["tosa_dim_order"] = NNHWC_ORDER
207+
else:
208+
permute_node.meta["tosa_dim_order"] = NHWC_ORDER
209+
186210
node.meta["tosa_dim_order"] = tuple(
187211
range(len(get_first_fake_tensor(node).size()))
188212
)
@@ -261,7 +285,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
261285
]
262286
for input_node in inputs:
263287
input_dim_order = get_first_fake_tensor(input_node).dim_order()
264-
if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
288+
if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER):
265289
self.insert_output_transpose(input_node, graph_module)
266290

267291
# Transpose outputs if they are in (N)NCHW format
@@ -276,6 +300,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
276300
if output_dim_order in (
277301
NCHW_ORDER,
278302
NNCHW_ORDER,
303+
NNNCHW_ORDER,
279304
):
280305
self.insert_input_transpose(
281306
output_node, output_node_input, graph_module
@@ -313,6 +338,8 @@ def call(self, graph_module: torch.fx.GraphModule):
313338
dim_order = HWCM_ORDER
314339
elif node_data.dim() == 5:
315340
dim_order = NNHWC_ORDER
341+
elif node_data.dim() == 6:
342+
dim_order = NNNHWC_ORDER
316343
else:
317344
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]
318345

backends/arm/constants.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,13 @@
3434
NHWC_INVERSE_ORDER: Final = (0, 3, 1, 2)
3535
NNHWC_ORDER: Final = (0, 1, 3, 4, 2)
3636
NNHWC_INVERSE_ORDER: Final = (0, 1, 4, 2, 3)
37+
NNNHWC_ORDER: Final = (0, 1, 2, 4, 5, 3)
38+
NNNHWC_INVERSE_ORDER: Final = (0, 1, 2, 5, 3, 4)
3739

3840
NCHW_ORDER: Final = (0, 1, 2, 3)
39-
NCHW_INVERSE_ORDER: Final = (0, 2, 3, 1)
4041
NNCHW_ORDER: Final = (0, 1, 2, 3, 4)
41-
NNCHW_INVERSE_ORDER: Final = (0, 1, 3, 4, 2)
42+
NNNCHW_ORDER: Final = (0, 1, 2, 3, 4, 5)
4243

4344
HWCM_ORDER: Final = (2, 3, 0, 1)
45+
46+
MAX_RANK: Final = 6

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 1 addition & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919

2020
# INT profile: ops supported via native TOSA ops, decompositions/transformations, precompute, TableOps, etc.
21+
# Note that ops supported via pre-quantization decompositions are not included here.
2122
TOSA_PRO_INT_SupportList: Final[Set] = {
2223
exir_ops.edge.aten.abs.default,
2324
exir_ops.edge.aten.add.Tensor,
@@ -46,8 +47,6 @@
4647
exir_ops.edge.aten.hardsigmoid.default,
4748
exir_ops.edge.aten.hardtanh.default,
4849
exir_ops.edge.aten.hardswish.default,
49-
exir_ops.edge.aten.div.Tensor,
50-
exir_ops.edge.aten.div.Tensor_mode,
5150
exir_ops.edge.aten.eq.Tensor,
5251
exir_ops.edge.aten.eq.Scalar,
5352
exir_ops.edge.aten.erf.default,
@@ -68,16 +67,7 @@
6867
exir_ops.edge.aten.lt.Tensor,
6968
exir_ops.edge.aten.lt.Scalar,
7069
exir_ops.edge.aten.mul.Tensor,
71-
exir_ops.edge.aten.ne.Tensor,
72-
exir_ops.edge.aten.ne.Scalar,
7370
exir_ops.edge.aten.neg.default,
74-
exir_ops.edge.aten.add.Scalar,
75-
exir_ops.edge.aten.sub.Scalar,
76-
exir_ops.edge.aten.mul.Scalar,
77-
exir_ops.edge.aten.div.Scalar,
78-
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
79-
exir_ops.edge.aten.native_layer_norm.default,
80-
exir_ops.edge.aten.native_group_norm.default,
8171
exir_ops.edge.aten.sigmoid.default,
8272
exir_ops.edge.aten.mean.dim,
8373
exir_ops.edge.aten.mm.default,
@@ -86,19 +76,12 @@
8676
exir_ops.edge.aten.repeat.default,
8777
exir_ops.edge.aten.reciprocal.default,
8878
exir_ops.edge.aten.relu.default,
89-
exir_ops.edge.aten.leaky_relu.default,
90-
exir_ops.edge.aten.sqrt.default,
9179
exir_ops.edge.aten.rsqrt.default,
92-
exir_ops.edge.aten.round.default,
93-
exir_ops.edge.aten._softmax.default,
9480
exir_ops.edge.aten.select_copy.int,
95-
exir_ops.edge.aten._log_softmax.default,
9681
exir_ops.edge.aten.sub.Tensor,
9782
exir_ops.edge.aten.tanh.default,
9883
exir_ops.edge.aten.upsample_bilinear2d.vec,
9984
exir_ops.edge.aten.upsample_nearest2d.vec,
100-
exir_ops.edge.aten.var.correction,
101-
exir_ops.edge.aten.var.dim,
10285
exir_ops.edge.aten.view_copy.default,
10386
exir_ops.edge.aten.unsqueeze_copy.default,
10487
exir_ops.edge.aten.squeeze_copy.dims,
@@ -127,12 +110,9 @@
127110
exir_ops.edge.aten.sign.default,
128111
exir_ops.edge.aten.asin.default,
129112
exir_ops.edge.aten.atanh.default,
130-
exir_ops.edge.aten.addmm.default,
131113
exir_ops.edge.aten.masked_fill.Scalar,
132114
exir_ops.edge.aten.asinh.default,
133115
exir_ops.edge.aten.cosh.default,
134-
exir_ops.edge.aten.glu.default,
135-
exir_ops.edge.aten.logit.default,
136116
exir_ops.edge.aten.acos.default,
137117
exir_ops.edge.aten.elu.default,
138118
exir_ops.edge.aten.bitwise_not.default,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
FuseQuantizedActivationPass,
2020
)
2121
from executorch.backends.arm._passes.insert_table_ops import TableOps
22-
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
22+
from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS
2323
from executorch.backends.arm.operator_support.ethos_u55_support import (
2424
EthosU55CastCheck,
2525
EthosU55DtypeSupport,
@@ -127,15 +127,14 @@ def tosa_support_factory(
127127
negative_checks: list[OperatorSupportBase] = [
128128
CheckInt64InputsAndOutputs(exported_program, reporter),
129129
CheckFloat64Inputs(exported_program, reporter),
130-
RankCheck(reporter, max_rank=5),
130+
RankCheck(reporter, max_rank=MAX_RANK),
131131
*[
132132
reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
133133
for check in (additional_checks if additional_checks else [])
134134
],
135135
]
136136

137137
if not tosa_spec.support_float():
138-
negative_checks.append(NeedsDecompositionCheck(reporter))
139138
negative_checks.append(CheckProperQuantization(reporter))
140139
if tosa_spec.is_U55_subset:
141140
negative_checks.append(EthosU55NotSupported(reporter))
@@ -156,7 +155,8 @@ def tosa_support_factory(
156155
class TOSAProINTSupportList(OperatorSupportBase):
157156
"""
158157
TOSA_PRO_INT_SupportList:
159-
Ops supported in INT profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOps
158+
Ops supported in INT profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOps.
159+
Note that ops supported via pre-quantization decompositions are not included here.
160160
"""
161161

162162
def is_node_supported(
@@ -179,57 +179,6 @@ def is_node_supported(
179179
return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
180180

181181

182-
class NeedsDecompositionCheck(OperatorSupportBase):
183-
"""
184-
Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding
185-
the operator, and to get optimal quantization parameters for each operator. This check will reject operators
186-
that need to be decomposed.
187-
"""
188-
189-
def __init__(self, reporter: WhyNoPartitionReporter):
190-
self.reporter = reporter
191-
192-
def is_node_supported(
193-
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
194-
) -> bool:
195-
196-
if node.op != "call_function":
197-
return True
198-
199-
needs_decomp_dict = {
200-
exir_ops.edge.aten.div.Tensor: None,
201-
exir_ops.edge.aten._native_batch_norm_legit_no_training.default: "BatchNorm2D with track_running_stats==True not immediately following a convolution is not supported for quantized TOSA backends.",
202-
exir_ops.edge.aten.native_layer_norm.default: None,
203-
exir_ops.edge.aten.native_group_norm.default: None,
204-
exir_ops.edge.aten._softmax.default: None,
205-
exir_ops.edge.aten._log_softmax.default: None,
206-
exir_ops.edge.aten.var.correction: None,
207-
exir_ops.edge.aten.var.dim: None,
208-
exir_ops.edge.aten.add.Scalar: None,
209-
exir_ops.edge.aten.sqrt.default: None,
210-
exir_ops.edge.aten.sub.Scalar: None,
211-
exir_ops.edge.aten.mul.Scalar: None,
212-
exir_ops.edge.aten.ne.Tensor: None,
213-
exir_ops.edge.aten.ne.Scalar: None,
214-
exir_ops.edge.aten.div.Scalar: None,
215-
exir_ops.edge.aten.leaky_relu.default: None,
216-
exir_ops.edge.aten.round.default: None,
217-
exir_ops.edge.aten.addmm.default: None,
218-
exir_ops.edge.aten.glu.default: None,
219-
exir_ops.edge.aten.logit.default: None,
220-
}
221-
222-
if node.target in needs_decomp_dict:
223-
reject_message = needs_decomp_dict[node.target]
224-
if reject_message is None:
225-
reject_message = "Op needs to be decomposed into other ops before quantization to get quantized properly."
226-
227-
self.reporter.report_reject(node, reject_message)
228-
return False
229-
else:
230-
return True
231-
232-
233182
class CheckProperQuantization(OperatorSupportBase):
234183
"""
235184
For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -370,6 +370,8 @@ def _match_pattern(
370370
torch.ops.aten.dropout_.default,
371371
torch.ops.aten.adaptive_avg_pool2d.default,
372372
torch.ops.aten.alias_copy.default,
373+
torch.ops.aten.pixel_shuffle.default,
374+
torch.ops.aten.pixel_unshuffle.default,
373375
]
374376

375377

backends/arm/scripts/parse_test_names.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@
2626
"_native_batch_norm_legit_no_training.default",
2727
"_native_batch_norm_legit.no_stats",
2828
"alias_copy.default",
29+
"pixel_shuffle.default",
30+
"pixel_unshuffle.default",
2931
]
3032
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
3133

0 commit comments

Comments
 (0)