
Commit b183411

chunnienc authored and copybara-github committed
Add enable_group_norm_composite global flag
PiperOrigin-RevId: 712602061
1 parent: b9c7180

4 files changed: 60 additions, 64 deletions

ai_edge_torch/_config.py
26 additions, 9 deletions

@@ -22,6 +22,18 @@
 __all__ = ["config"]
 
 
+def _get_bool_env_var(name: str, default: bool) -> bool:
+  var = os.environ.get(name, str(default))
+  var = var.lower().strip()
+  if var in ("y", "yes", "t", "true", "on", "1"):
+    return True
+  elif var in ("n", "no", "f", "false", "off", "0"):
+    return False
+  else:
+    logging.warning("Invalid %s value is ignored: %s.", name, var)
+    return default
+
+
 class _Config:
   """ai-edge-torch global configs."""
 
@@ -33,20 +45,25 @@ def use_torch_xla(self) -> bool:
     To use torch_xla as the lowering backend, set environment variable
     `USE_TORCH_XLA` to "true".
     """
-    var = os.environ.get("USE_TORCH_XLA", "false")
-    var = var.lower().strip()
-    if var in ("y", "yes", "t", "true", "on", "1"):
-      return True
-    elif var in ("n", "no", "f", "false", "off", "0"):
-      return False
-    else:
-      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
-      return False
+    return _get_bool_env_var("USE_TORCH_XLA", default=False)
 
   @property
   def in_oss(self) -> bool:
     """True if the code is not running in google internal environment."""
     return True
 
+  @property
+  def enable_group_norm_composite(self) -> bool:
+    """True if group norm should be lowered as a StableHLO composite.
+
+    Currently only supports NHWC group norm generated by
+    OptimizeLayoutTransposesPass.
+    """
+    return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+  @enable_group_norm_composite.setter
+  def enable_group_norm_composite(self, value: bool):
+    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+
 
 config = _Config()
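With this in place, the flag can be flipped either through the environment or through the config object. A minimal sketch; both spellings follow the parser and setter shown above:

import os

# Option 1: environment variable, read lazily by the property. Any of
# "y", "yes", "t", "true", "on", "1" (case-insensitive) enables it.
os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "true"

# Option 2: the new setter, which writes the same variable ("y"/"n").
import ai_edge_torch

ai_edge_torch.config.enable_group_norm_composite = True
assert ai_edge_torch.config.enable_group_norm_composite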

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py
5 additions, 1 deletion

@@ -17,6 +17,7 @@
 import dataclasses
 import operator
 
+import ai_edge_torch
 from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
 
@@ -194,7 +195,10 @@ def _aten_group_norm_checker(node):
   val = node.meta.get("val")
   if not hasattr(val, "shape"):
     return NHWCable(can_be=False, must_be=False)
-  return NHWCable(can_be=len(val.shape) == 4, must_be=False)
+
+  can_be = len(val.shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
 
 
 @nhwcable_node_checkers.register(aten.native_group_norm)
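Read together, the checker reduces to a small decision rule. A minimal standalone sketch; the two NHWCable fields follow the pass's usage above, and the free-standing function name is hypothetical:

import dataclasses


@dataclasses.dataclass
class NHWCable:
  # Mirrors the result type used by the pass: can_be means the node
  # tolerates NHWC layout; must_be means the pass must pick NHWC.
  can_be: bool
  must_be: bool


def group_norm_checker(shape, enable_composite: bool) -> NHWCable:
  # Only 4D activations can be laid out NHWC. When the composite flag is
  # on, NHWC becomes mandatory so the rewriter can wrap the op in an
  # odml.group_norm composite with channel_axis=3.
  can_be = len(shape) == 4
  return NHWCable(can_be=can_be, must_be=can_be and enable_composite)


print(group_norm_checker((1, 16, 8, 8), enable_composite=True))
# NHWCable(can_be=True, must_be=True)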

ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
27 additions, 4 deletions

@@ -16,13 +16,15 @@
 
 import operator
 
+import ai_edge_torch
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
 import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
+StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
@@ -345,11 +347,32 @@ def batch_norm(input, weight, bias, running_mean, running_var, momentum, eps):
 @rewriters.register(aten.group_norm.default)
 def _aten_group_norm(node):
   def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
-    # Disable NHWC rewriter with native decomposed ops due to precision issue.
-    # TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
+    )
+
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
+
     input = utils.tensor_to_nchw(input)
-    res = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
-    return utils.tensor_to_nhwc(res)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
+
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return output
 
   node.target = group_norm
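Conceptually, the rewriter emits the pattern below: the composite boundary is marked on NHWC tensors, while the computation briefly round-trips through NCHW because aten.group_norm expects channels-first input. A minimal standalone sketch, assuming the ai_edge_torch.hlfb export aliased in the diff; the helper name nhwc_group_norm is illustrative, and utils.tensor_to_nchw/tensor_to_nhwc are plain permutes:

import torch

from ai_edge_torch.hlfb import StableHLOCompositeBuilder


def nhwc_group_norm(x, weight, bias, num_groups: int, eps: float = 1e-5):
  # Mark the composite boundary on the NHWC tensors (channel_axis=3).
  builder = StableHLOCompositeBuilder(
      name="odml.group_norm",
      attr={
          "num_groups": num_groups,
          "epsilon": eps,
          "reduction_axes": [3],
          "channel_axis": 3,
      },
  )
  x, weight, bias = builder.mark_inputs(x, weight, bias)
  x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW for the aten op.
  y = torch.ops.aten.group_norm.default(x, num_groups, weight, bias, eps=eps)
  y = y.permute(0, 2, 3, 1)  # Back to NHWC at the boundary.
  return builder.mark_outputs(y)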

ai_edge_torch/generative/layers/normalization.py
2 additions, 50 deletions

@@ -80,6 +80,7 @@ def forward(self, x):
     output = self._norm(x.float()).type_as(x)
     return output * w
 
+
 class GroupNorm(torch.nn.Module):
 
   def __init__(
 
@@ -115,16 +116,7 @@ def forward(self, x):
     Returns:
       torch.Tensor: output tensor after applying GroupNorm.
     """
-    if self.enable_hlfb:
-      return group_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.group_num,
-          self.eps,
-      )
-    else:
-      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+    return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
 
 
 class LayerNorm(torch.nn.Module):
 
@@ -169,46 +161,6 @@ def forward(self, x):
     )
 
 
-def group_norm_with_hlfb(
-    x: torch.Tensor,
-    w: torch.Tensor,
-    b: torch.Tensor,
-    num_groups: int,
-    eps: float,
-):
-  """Group Normalization with high-level function boundary enabled.
-
-  Args:
-    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
-    w (torch.Tensor): The weight tensor for the normalization.
-    b (torch.Tensor): The bias tensor for the normalization.
-    num_groups (int): Number of groups to separate the channels into.
-    eps (float): A small float value to ensure numerical stability.
-
-  Returns:
-    The output tensor of Group Normalization.
-  """
-  x = torch.permute(x, (0, 2, 3, 1))
-
-  builder = StableHLOCompositeBuilder(
-      name="odml.group_norm",
-      attr={
-          "num_groups": num_groups,
-          "epsilon": eps,
-          "reduction_axes": [3],
-          "channel_axis": 3,
-      },
-  )
-  x, w, b = builder.mark_inputs(x, w, b)
-  x = torch.permute(x, (0, 3, 1, 2))
-  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
-  y = torch.permute(y, (0, 2, 3, 1))
-  y = builder.mark_outputs(y)
-
-  y = torch.permute(y, (0, 3, 1, 2))
-  return y
-
-
 def rms_norm_with_hlfb(
     x: torch.Tensor,
     w: torch.Tensor,
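The net effect is that HLFB for group norm moves out of the model code: GroupNorm always calls F.group_norm, and the composite wrapping is applied (or not) at conversion time by the layout pass, driven by the global flag. A minimal end-to-end sketch, assuming ai_edge_torch.convert and EdgeModel.export as the public conversion API; the model itself is illustrative:

import torch
import ai_edge_torch


class TinyModel(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.norm = torch.nn.GroupNorm(num_groups=4, num_channels=16)

  def forward(self, x):
    return self.norm(x)


# Opt in to the composite lowering before conversion.
ai_edge_torch.config.enable_group_norm_composite = True
model = TinyModel().eval()
edge_model = ai_edge_torch.convert(model, (torch.randn(1, 16, 8, 8),))
edge_model.export("/tmp/tiny_group_norm.tflite")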
