
Commit b215780

chunnienc authored and copybara-github committed
Update NHWC rewriter for native_group_norm.
PiperOrigin-RevId: 718944411
1 parent dbb82b3 commit b215780

File tree: 3 files changed (+81, -74 lines)

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py

Lines changed: 11 additions & 8 deletions
@@ -201,22 +201,25 @@ def _aten_group_norm_checker(node):
   return NHWCable(can_be=can_be, must_be=must_be)
 
 
-@nhwcable_node_checkers.register(aten.native_group_norm)
+@nhwcable_node_checkers.register(aten.native_group_norm.default)
 def _aten_native_group_norm_checker(node):
+  # aten.group_norm is removed from the decomp table, so aten.native_group_norm
+  # should never exist in the graph. However, torch 2.5.1 could ignore the
+  # decomp table updates, so still add this native_group_norm checker and
+  # rewriter to be safe.
+  # The checker and rewriter are the same as the ones for aten.group_norm.
+
   val = node.meta.get("val")
   if (
       not isinstance(val, (list, tuple))
       or not val
       or not hasattr(val[0], "shape")
   ):
     return NHWCable(can_be=False, must_be=False)
-  if len(node.args) >= 3 and (
-      node.args[1] is not None or node.args[2] is not None
-  ):
-    # Disable NHWC rewriter due to precision issue with weight and bias.
-    # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
-    return NHWCable(can_be=False, must_be=False)
-  return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
+
+  can_be = len(val[0].shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
 
 
 # ==== Ops must be NCHW

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py

Lines changed: 22 additions & 24 deletions
@@ -391,34 +391,32 @@ def native_group_norm(
       eps: float,
       **kwargs,
   ):
-    input_reshaped = torch.reshape(
-        input,
-        [
-            batch_size,
-            flattened_inner_size,
-            num_groups,
-            num_channels // num_groups,
-        ],
-    )
-    reduction_dims = [1, 3]
-
-    biased_var, mean = torch.var_mean(
-        input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
     )
-    rstd = torch.rsqrt(biased_var + eps)
-
-    out = (input_reshaped - mean) * rstd
-    out = torch.reshape(out, input.shape)
 
-    if weight is not None:
-      out = out * weight
-    if bias is not None:
-      out = out + bias
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
 
-    mean = torch.squeeze(mean, reduction_dims)
-    rstd = torch.squeeze(rstd, reduction_dims)
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
 
-    return out, mean, rstd
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return (output, None, None)
 
   node.target = native_group_norm
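The rewriter's core pattern is the StableHLOCompositeBuilder mark_inputs/mark_outputs pair: every tensor passed through them is tagged in the exported graph, and the region between the marks is fused into a single stablehlo.composite op named odml.group_norm during lowering. A standalone sketch of the same pattern, assuming the builder is imported from ai_edge_torch.hlfb as elsewhere in the package, and applied here to a plain functional group_norm rather than the pass's NHWC plumbing:

import torch
from ai_edge_torch.hlfb import StableHLOCompositeBuilder

def group_norm_composite(x, weight, bias, num_groups: int, eps: float = 1e-5):
  builder = StableHLOCompositeBuilder(
      name="odml.group_norm",
      attr={"num_groups": num_groups, "epsilon": eps},
  )
  # Tensors marked here delimit the region that is fused into one
  # stablehlo.composite op when the exported program is lowered.
  x, weight, bias = builder.mark_inputs(x, weight, bias)
  out = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps)
  return builder.mark_outputs(out)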

ai_edge_torch/_convert/fx_passes/test/test_optimize_layout_transposes_pass.py

Lines changed: 48 additions & 42 deletions
@@ -16,7 +16,9 @@
 
 from typing import Callable, Union
 
+import ai_edge_torch
 from ai_edge_torch import fx_infra
+from ai_edge_torch import lowertools
 from ai_edge_torch._convert import fx_passes
 import torch
 import torch.utils._pytree as pytree
@@ -48,10 +50,7 @@ def forward(self, *args, **kwargs):
   )
   exported_program = fx_infra.run_passes(
       exported_program,
-      [
-          fx_passes.OptimizeLayoutTransposesPass(),
-          fx_passes.CanonicalizePass(),
-      ],
+      [fx_passes.OptimizeLayoutTransposesPass()],
   )
   return exported_program
 
@@ -90,27 +89,19 @@ def test_torchvision_resnet18(self):
         model, exported_program.module(), forward_args()
     )
 
-  def test_native_group_norm_no_weight_bias(self):
-    batch_size = 16
-    num_channels = 640
-    flattened_inner_size = 256
-    num_groups = 32
-    eps = 1e-6
+  def test_group_norm_affine_false(self):
 
     class SampleModel(torch.nn.Module):
 
+      def __init__(self):
+        super().__init__()
+        self.group_norm = torch.nn.GroupNorm(
+            num_groups=32, num_channels=640, affine=False, eps=1e-6
+        )
+
       def forward(self, x):
         x = torch.nn.AvgPool2d(2)(x)
-        x = torch.ops.aten.native_group_norm(
-            x,
-            None,
-            None,
-            batch_size,
-            num_channels,
-            flattened_inner_size,
-            num_groups,
-            eps,
-        )[0]
+        x = self.group_norm(x)
         x = torch.nn.AvgPool2d(2)(x)
         return x
 
@@ -121,41 +112,56 @@ def forward(self, x):
         model, exported_program.module(), forward_args()
     )
 
-  def test_native_group_norm_large_weight_bias(self):
-    batch_size = 16
-    num_channels = 640
-    flattened_inner_size = 256
-    num_groups = 32
-    eps = 1e-6
+  def test_group_norm_large_affine_true(self):
 
     class SampleModel(torch.nn.Module):
 
-      def forward(self, x, weight, bias):
+      def __init__(self):
+        super().__init__()
+        self.group_norm = torch.nn.GroupNorm(
+            num_groups=32, num_channels=640, affine=True, eps=1e-6
+        )
+
+      def forward(self, x):
         x = torch.nn.AvgPool2d(2)(x)
-        x = torch.ops.aten.native_group_norm(
-            x,
-            weight,
-            bias,
-            batch_size,
-            num_channels,
-            flattened_inner_size,
-            num_groups,
-            eps,
-        )[0]
+        x = self.group_norm(x)
         x = torch.nn.AvgPool2d(2)(x)
         return x
 
     model = SampleModel().eval()
-    forward_args = lambda: (
-        torch.rand(16, 640, 32, 32) * 1000,
-        torch.rand([640]) * 1000,
-        torch.rand([640]) * 1000,
+    forward_args = lambda: (torch.rand(16, 640, 32, 32) * 1000,)
+    exported_program = export_with_pass(model, forward_args())
+    self.assert_outputs_allclose(
+        model, exported_program.module(), forward_args()
     )
+
+  def test_group_norm_with_composite_enabled(self):
+    ai_edge_torch.config.enable_group_norm_composite = True
+
+    class SampleModel(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.group_norm = torch.nn.GroupNorm(
+            num_groups=2, num_channels=10, affine=True
+        )
+
+      def forward(self, x):
+        x = torch.nn.AvgPool2d(2)(x)
+        x = self.group_norm(x)
+        x = torch.nn.AvgPool2d(2)(x)
+        return x
+
+    model = SampleModel().eval()
+    forward_args = lambda: (torch.rand(1, 10, 32, 32),)
     exported_program = export_with_pass(model, forward_args())
     self.assert_outputs_allclose(
         model, exported_program.module(), forward_args()
     )
 
+    ir_text = lowertools.exported_program_to_mlir_text(exported_program)
+    self.assertEqual(ir_text.count("stablehlo.custom_call @mark_tensor"), 4)
+
 
-if __name__ == '__main__':
+if __name__ == "__main__":
   googletest.main()
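The expected count of 4 in the new assertion follows directly from the rewriter: builder.mark_inputs(input, weight, bias) tags three tensors and builder.mark_outputs(output) tags one, and each tagged tensor lowers to one stablehlo.custom_call @mark_tensor in the MLIR text. A hedged restatement of that check, reusing the names defined in the test above (export_with_pass, model, forward_args):

# 3 marked inputs (input, weight, bias) + 1 marked output = 4 markers.
exported_program = export_with_pass(model, forward_args())
ir_text = lowertools.exported_program_to_mlir_text(exported_program)
assert ir_text.count("stablehlo.custom_call @mark_tensor") == 4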
