Skip to content

Commit a9955e5

Browse files
authored
[Relax][PyTorch] Add decomposed operator support for normalization (#18460)
## Related Issue - #18401 ## How This PR - added `_batch_norm_legit_no_stats` - added `_native_group_norm` - added `any.dims` - refactored `_reshape`
1 parent 45a2a40 commit a9955e5

File tree

3 files changed

+69
-11
lines changed

3 files changed

+69
-11
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
def _reshape(self, node: fx.Node) -> relax.Var:
    """Translate a torch reshape/view node into ``relax.op.reshape``.

    The target shape may arrive either as a single sequence argument
    (``torch.Size``/tuple/list) or spread across the remaining
    positional arguments. An identity reshape (target shape equal to
    the current shape) is elided and the input var is returned as-is.
    """
    call_args = self.retrieve_args(node)
    tensor = call_args[0]

    shape_arg = call_args[1]
    if isinstance(shape_arg, (torch.Size, tuple, list)):
        target_dims = shape_arg
    else:
        target_dims = call_args[1:]

    # No-op when the requested shape matches the existing one.
    # NOTE(review): this equality assumes the dims compare by value
    # (plain ints / IntImms); symbolic shape dims may not match — confirm.
    if list(self.shape_of(tensor)) == list(target_dims):
        return tensor

    return self.block_builder.emit(relax.op.reshape(tensor, target_dims))
18521858

18531859
def _reshape_as(self, node: fx.Node) -> relax.Var:

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
    """Handle the aten batch-norm variant that always runs in inference mode."""
    # Delegate to the shared batch-norm converter with training disabled.
    return self._batch_norm(node, False)
115115

116+
def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var:
    """Map ``aten::_native_batch_norm_legit.no_stats`` onto relax ``instance_norm``.

    With no running stats available the statistics are computed from the
    input itself over the spatial axes. Missing affine parameters fall
    back to an identity transform (weight=1, bias=0).

    NOTE(review): ``instance_norm`` uses per-sample statistics; true
    batch statistics would also aggregate over the batch axis — confirm
    this matches the intended decomposition.
    """
    import numpy as np

    data = self.env[node.args[0]]
    num_channels = int(self.shape_of(data)[1])
    dtype = data.struct_info.dtype

    # Identity affine transform when gamma/beta are absent from the graph.
    gamma = self.env.get(node.args[1], relax.const(np.ones(num_channels), dtype=dtype))
    beta = self.env.get(node.args[2], relax.const(np.zeros(num_channels), dtype=dtype))

    # aten signature: (input, weight, bias, training, momentum, eps)
    eps = node.args[5] if len(node.args) > 5 else node.kwargs.get("eps", 1e-05)

    # Normalize over every axis after the channel axis.
    rank = len(self.shape_of(data))
    spatial_axes = list(range(2, rank))

    return self.block_builder.emit(
        relax.op.nn.instance_norm(
            data,
            gamma,
            beta,
            channel_axis=1,
            axes=spatial_axes,
            epsilon=eps,
        )
    )
140+
116141
def _cross_entropy_default(self, node: fx.Node) -> relax.Expr:
117142
preds = self.env[node.args[0]]
118143
targets = self.env[node.args[1]]
@@ -141,6 +166,28 @@ def _group_norm(self, node: fx.Node) -> relax.Var:
141166
)
142167
)
143168

169+
def _native_group_norm(self, node: fx.Node) -> relax.Var:
    """Map ``aten::native_group_norm`` onto relax ``group_norm``.

    aten signature: (input, weight, bias, N, C, HxW, group, eps).
    N / C / HxW (args 3-5) are redundant with the input's shape and are
    ignored here; normalization runs over all axes after the channel axis.
    """
    data = self.env[node.args[0]]

    # gamma/beta are optional — keep them as None when not supplied.
    gamma = self.env.get(node.args[1], None) if len(node.args) > 1 else None
    beta = self.env.get(node.args[2], None) if len(node.args) > 2 else None

    groups = node.args[6] if len(node.args) > 6 else 1
    eps = node.args[7] if len(node.args) > 7 else 1e-05

    rank = len(self.shape_of(data))
    return self.block_builder.emit(
        relax.op.nn.group_norm(
            data,
            gamma,
            beta,
            num_groups=groups,
            channel_axis=1,
            axes=list(range(2, rank)),
            epsilon=eps,
        )
    )
190+
144191
def _upsample_impl(
145192
self,
146193
x: relax.Expr,
@@ -963,6 +1010,7 @@ def create_convert_map(
9631010
"_adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
9641011
"_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional,
9651012
"_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
1013+
"_native_batch_norm_legit.no_stats": self._batch_norm_legit_no_stats,
9661014
"batch_norm.default": self._batch_norm_legit_no_training,
9671015
"adaptive_avg_pool1d.default": self._adaptive_avg_pool1d,
9681016
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
@@ -988,6 +1036,7 @@ def create_convert_map(
9881036
),
9891037
"group_norm.default": self._group_norm,
9901038
"instance_norm.default": self._instance_norm,
1039+
"native_group_norm.default": self._native_group_norm,
9911040
"layer_norm.default": self._layer_norm,
9921041
"linear.default": self._linear,
9931042
"lstm.input": self._lstm,
@@ -1004,6 +1053,7 @@ def create_convert_map(
10041053
"upsample_bicubic2d.vec": self._upsample_bicubic2d,
10051054
# statistical
10061055
"any.dim": self._any,
1056+
"any.dims": self._any,
10071057
"mean.dim": self._mean,
10081058
"prod.default": self._prod,
10091059
"std.correction": self._std,

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,20 +1514,18 @@ def main(
15141514
x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32")
15151515
) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
15161516
with R.dataflow():
1517-
lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(x, axis=[-1])
1518-
lv1: R.Tensor((8,), dtype="float32") = R.reshape(test_elements, R.shape([8]))
1519-
lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1)
1520-
lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False)
1521-
lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32"))
1522-
gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,)
1517+
lv: R.Tensor((10, 10, 1), dtype="float32") = R.reshape(x, R.shape([10, 10, 1]))
1518+
lv1: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, test_elements)
1519+
lv2: R.Tensor((10, 10), dtype="bool") = R.max(lv1, axis=[-1], keepdims=False)
1520+
gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv2,)
15231521
R.output(gv)
15241522
return gv
15251523

15261524
example_args = (
15271525
torch.randn(10, 10, dtype=torch.float32),
15281526
torch.randn(8, dtype=torch.float32),
15291527
)
1530-
verify_model(IsInModel(), example_args, {}, expected)
1528+
verify_model(IsInModel(), example_args, {}, expected, run_ep_decomposition=True)
15311529

15321530

15331531
def test_div_mode():
@@ -3155,7 +3153,7 @@ def main(
31553153
"w1": model.gn.weight.detach().numpy(),
31563154
"w2": model.gn.bias.detach().numpy(),
31573155
}
3158-
verify_model(model, example_args, binding, expected1)
3156+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
31593157

31603158

31613159
def test_instancenorm2d():
@@ -3200,7 +3198,7 @@ def main(
32003198
"w1": torch.ones(3).detach().numpy(),
32013199
"w2": torch.zeros(3).detach().numpy(),
32023200
}
3203-
verify_model(model, example_args, binding, expected1)
3201+
verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
32043202

32053203

32063204
def test_layernorm():
@@ -5556,7 +5554,9 @@ def main(
55565554

55575555
example_args = (torch.randn(256, 256, dtype=torch.float32),)
55585556
exported_program = export(Identity(), args=example_args)
5559-
mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True)
5557+
mod = from_exported_program(
5558+
exported_program, unwrap_unit_return_tuple=True, run_ep_decomposition=True
5559+
)
55605560
tvm.ir.assert_structural_equal(mod, Expected)
55615561

55625562

@@ -5586,7 +5586,9 @@ def main(
55865586
torch.randn(256, 256, dtype=torch.float32),
55875587
)
55885588
exported_program = export(Identity(), args=example_args)
5589-
mod = from_exported_program(exported_program, no_bind_return_tuple=True)
5589+
mod = from_exported_program(
5590+
exported_program, no_bind_return_tuple=True, run_ep_decomposition=True
5591+
)
55905592
tvm.ir.assert_structural_equal(mod, Expected)
55915593

55925594

0 commit comments

Comments
 (0)