Skip to content

Commit 58a7cde

Browse files
chunnienccopybara-github
authored andcommitted
add layout opt rule for group norm
PiperOrigin-RevId: 712372815
1 parent 7e68dce commit 58a7cde

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
155155
@layout_sensitive_inputs_getters.register(
156156
aten._native_batch_norm_legit_no_training
157157
)
158+
@layout_sensitive_inputs_getters.register(aten.group_norm)
158159
@layout_sensitive_inputs_getters.register(aten.native_group_norm)
159160
def _first_arg_getter(node):
160161
return [node.args[0]]
@@ -188,6 +189,14 @@ def _aten_norm_checker(node):
188189
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
189190

190191

192+
@nhwcable_node_checkers.register(aten.group_norm)
193+
def _aten_group_norm_checker(node):
194+
val = node.meta.get("val")
195+
if not hasattr(val, "shape"):
196+
return NHWCable(can_be=False, must_be=False)
197+
return NHWCable(can_be=len(val.shape) == 4, must_be=False)
198+
199+
191200
@nhwcable_node_checkers.register(aten.native_group_norm)
192201
def _aten_native_group_norm_checker(node):
193202
val = node.meta.get("val")

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,18 @@ def batch_norm(input, weight, bias, running_mean, running_var, momentum, eps):
342342
node.target = batch_norm
343343

344344

345+
@rewriters.register(aten.group_norm.default)
346+
def _aten_group_norm(node):
347+
def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
348+
# Disable NHWC rewriter with native decomposied ops due to precision issue.
349+
# TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
350+
input = utils.tensor_to_nchw(input)
351+
res = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
352+
return utils.tensor_to_nhwc(res)
353+
354+
node.target = group_norm
355+
356+
345357
@rewriters.register(aten.native_group_norm.default)
346358
def _aten_native_group_norm(node):
347359

@@ -354,6 +366,7 @@ def native_group_norm(
354366
flattened_inner_size: int,
355367
num_groups: int,
356368
eps: float,
369+
**kwargs,
357370
):
358371
input_reshaped = torch.reshape(
359372
input,

0 commit comments

Comments
 (0)