Commit 9325d08

More test coverage, cleanup
1 parent 9174ee0 commit 9325d08

6 files changed: +203 -20 lines changed

backends/test/harness/tester.py

Lines changed: 5 additions & 1 deletion
@@ -41,8 +41,12 @@ def __init__(
         example_inputs: Tuple[torch.Tensor],
         stage_classes: Dict[StageType, Callable] | None = None,
         dynamic_shapes: Optional[Tuple[Any]] = None,
+        training: bool = False,
     ):
-        module.eval()
+        if training:
+            module.train()
+        else:
+            module.eval()

         self.stage_classes = stage_classes or Tester.default_stage_classes()
         self.original_module = module
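
The new training flag controls whether the harness keeps the module in train mode or switches it to eval before the stages run. A minimal usage sketch (the module and inputs here are placeholders, not part of the commit):

    # Keeping the module in train mode means ops such as batch norm export
    # their training variants rather than the inference-only ones.
    tester = Tester(MyModule(), example_inputs, training=True)
    tester.export().to_edge_transform_and_lower()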

backends/xnnpack/_passes/decompose_batch_norm.py

Lines changed: 28 additions & 6 deletions
@@ -72,9 +72,14 @@ def can_decompose_batch_norm(
                 why(node, f"Channel dimension must be statically known, but was {input_meta.shape[1]}.")
             return False

-        if not is_param_node(exported_program, node.args[1]) or not is_param_node(exported_program, node.args[2]):
+        if node.args[1] is not None and not is_param_node(exported_program, node.args[1]):
             if why:
-                why(node, "Batch norm affine weight and bias must be static.")
+                why(node, "Batch norm affine weight must be static.")
+            return False
+
+        if node.args[2] is not None and not is_param_node(exported_program, node.args[2]):
+            if why:
+                why(node, "Batch norm affine bias must be static.")
             return False

         if not is_param_node(exported_program, node.args[3]) or not is_param_node(exported_program, node.args[4]):
@@ -87,6 +92,11 @@ def can_decompose_batch_norm(
                 why(node, "Batch norm epsilon must be static.")
             return False

+        if node.target == exir_ops.edge.aten.native_batch_norm.default and node.args[5] is not False:
+            if why:
+                why(node, "Training batch norm is not supported.")
+            return False
+
         return True

     @staticmethod
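
For reference, the positional indices checked above follow the ATen batch norm signature. A small standalone sketch of that argument order (generic ATen usage, not code from the commit):

    import torch

    # aten.native_batch_norm takes
    #   (input, weight, bias, running_mean, running_var, training, momentum, eps),
    # so args[1]/args[2] are the affine weight and bias (None when affine=False),
    # args[3]/args[4] are the running stats, and args[5] is the training flag
    # that the new check requires to be False.
    x = torch.randn(2, 3, 4, 4)
    out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
        x, None, None, torch.zeros(3), torch.ones(3), False, 0.1, 1e-5
    )
    print(out.shape)  # torch.Size([2, 3, 4, 4])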
@@ -103,11 +113,23 @@ def compute_w_and_b(
         """

         # See https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
-        denom = torch.sqrt(running_var + torch.Tensor([eps]))
-        weight = gamma / denom
-        bias = -running_mean * gamma / denom + beta

-        return weight, bias
+        # Do the math in double precision and convert back to the original dtype at the
+        # end. ATen kernels do this math in increased precision for float16. Note that
+        # all of the parameter dtypes must match, as per the ATen behavior.
+
+        # Also note that gamma and beta can be None if affine=False. This is equivalent
+        # to gamma = 1 and beta = 0.
+        gamma_f64 = gamma.double() if gamma is not None else torch.Tensor([1]).double()
+        beta_f64 = beta.double() if beta is not None else torch.Tensor([0]).double()
+        running_mean_f64 = running_mean.double()
+        running_var_f64 = running_var.double()
+
+        denom = torch.sqrt(running_var_f64 + torch.Tensor([eps]))
+        new_weight = gamma_f64 / denom
+        new_bias = -running_mean_f64 * gamma_f64 / denom + beta_f64
+
+        return new_weight.to(running_mean.dtype), new_bias.to(running_mean.dtype)

     def replace_bn_node_with_conv(
         self,
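
The folding here is the standard eval-mode identity: batch norm computes y = gamma * (x - mean) / sqrt(var + eps) + beta, which is an affine transform y = w * x + b with w = gamma / sqrt(var + eps) and b = beta - mean * w. A standalone sketch (not part of the commit) that checks the folded form against torch's own eval-mode batch norm:

    import torch
    import torch.nn.functional as F

    C, eps = 3, 1e-5
    x = torch.randn(2, C, 4, 4)
    gamma, beta = torch.randn(C), torch.randn(C)
    mean, var = torch.randn(C), torch.rand(C) + 0.5

    # Folded per-channel scale and shift, computed as in compute_w_and_b.
    w = gamma / torch.sqrt(var + eps)
    b = beta - mean * w

    ref = F.batch_norm(x, mean, var, gamma, beta, training=False, eps=eps)
    folded = x * w.view(1, C, 1, 1) + b.view(1, C, 1, 1)
    print(torch.allclose(ref, folded, atol=1e-6))  # expected: True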

backends/xnnpack/partition/config/node_configs.py

Lines changed: 0 additions & 3 deletions
@@ -16,9 +16,6 @@
     XNNPartitionerConfig,
 )
 from executorch.backends.xnnpack.utils.utils import is_param_node
-from executorch.exir.backend.canonical_partitioners.config_partitioner import (
-    format_target_name,
-)
 from executorch.exir.backend.utils import WhyNoPartition
 from torch.export import ExportedProgram

backends/xnnpack/test/ops/test_batch_norm.py

Lines changed: 84 additions & 2 deletions
@@ -50,11 +50,11 @@ def get_inputs(self):
     class BatchNorm2d(torch.nn.Module):
         """BatchNorm2d with NCHW input (batch, channels, height, width)."""

-        def __init__(self, num_features: int, dtype: torch.dtype = torch.float):
+        def __init__(self, num_features: int, dtype: torch.dtype = torch.float, affine: bool = True):
             super().__init__()
             self.num_features = num_features
             self.dtype = dtype
-            self.bn = torch.nn.BatchNorm2d(num_features).to(dtype)
+            self.bn = torch.nn.BatchNorm2d(num_features, affine=affine).to(dtype)

         def forward(self, x):
             return self.bn(x)
@@ -154,6 +154,28 @@ def test_fp16_batch_norm_nchw(self):
         """Test BatchNorm2d with fp16 NCHW input is lowered to XNNPACK."""
         self._test_batch_norm(self.BatchNorm2d(num_features=3, dtype=torch.float16))

+    def test_fp32_batch_norm_nchw_non_affine(self):
+        """Test non-affine BatchNorm2d with NCHW input is lowered to XNNPACK."""
+        self._test_batch_norm(self.BatchNorm2d(num_features=3, affine=False))
+
+    class BatchNorm2dChannelsLast(torch.nn.Module):
+        """BatchNorm2d with channels_last memory format input."""
+
+        def __init__(self, num_features: int):
+            super().__init__()
+            self.num_features = num_features
+            self.bn = torch.nn.BatchNorm2d(num_features)
+
+        def forward(self, x):
+            return self.bn(x)
+
+        def get_inputs(self):
+            return (torch.randn(2, self.num_features, 4, 4).to(memory_format=torch.channels_last),)
+
+    def test_fp32_batch_norm_nchw_channels_last(self):
+        """Test BatchNorm2d with channels_last memory format input is lowered to XNNPACK."""
+        self._test_batch_norm(self.BatchNorm2dChannelsLast(num_features=3))
+
     class BatchNorm3d(torch.nn.Module):
         """BatchNorm3d with NCDHW input (batch, channels, depth, height, width)."""

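
For readers unfamiliar with channels_last, a short standalone sketch (mine, not part of the commit) of how these tests build their inputs and how the memory format can be verified:

    import torch

    # NCHW-shaped tensor whose strides follow the channels_last layout,
    # matching what BatchNorm2dChannelsLast.get_inputs() returns.
    x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
    print(x.is_contiguous(memory_format=torch.channels_last))  # True

    bn = torch.nn.BatchNorm2d(3).eval()
    print(bn(x).shape)  # torch.Size([2, 3, 4, 4])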

@@ -277,3 +299,63 @@ def test_fp32_conv2d_batch_norm_fused(self):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    class Conv2dBatchNormChannelsLast(torch.nn.Module):
+        """Conv2d followed by BatchNorm (fuseable pattern) with channels_last input."""
+
+        def __init__(self, in_channels: int, out_channels: int):
+            super().__init__()
+            self.in_channels = in_channels
+            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+            self.bn = randomize_bn(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.bn(x)
+            return x
+
+        def get_inputs(self):
+            return (torch.randn(2, self.in_channels, 8, 8).to(memory_format=torch.channels_last),)
+
+    def test_fp32_conv2d_batch_norm_fused_channels_last(self):
+        """
+        Test Conv2d + BatchNorm with channels_last input where the BatchNorm is
+        fused into the Conv2d.
+        """
+        model = self.Conv2dBatchNormChannelsLast(in_channels=3, out_channels=8)
+        model.eval()
+
+        (
+            Tester(model, model.get_inputs())
+            .export()
+            .to_edge_transform_and_lower()
+            # BatchNorm should be fused into conv (not present in the graph)
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"
+                ]
+            )
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
+
+    def test_training_bn_not_partitioned(self):
+        """Test that training mode BatchNorm is not partitioned."""
+        model = self.BatchNorm2d(num_features=3)
+        for _ in range(5):
+            model(*model.get_inputs())
+
+        (
+            Tester(model, model.get_inputs(), training=True)
+            .export()
+            .to_edge_transform_and_lower()
+            .check(
+                [
+                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_functional"
+                ]
+            )
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
+            .run_method_and_compare_outputs()
+        )
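
The warm-up loop in test_training_bn_not_partitioned exercises the module in train mode, which updates the batch norm running statistics before export. A standalone illustration of that effect (not from the commit):

    import torch

    bn = torch.nn.BatchNorm2d(3)  # modules are in train mode by default
    before = bn.running_mean.clone()
    for _ in range(5):
        bn(torch.randn(2, 3, 4, 4) + 1.0)
    print(torch.allclose(before, bn.running_mean))  # False: stats were updated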

backends/xnnpack/test/passes/test_decompose_batch_norm.py

Lines changed: 84 additions & 8 deletions
@@ -9,6 +9,8 @@
 import torch
 from executorch.backends.xnnpack._passes.decompose_batch_norm import DecomposeBatchNorm
 from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+from executorch.exir import EdgeProgramManager
+from executorch.exir.dialects._ops import ops as exir_ops


 class TestDecomposeBatchNorm(unittest.TestCase):
@@ -46,9 +48,9 @@ def forward(self, x):
     class BatchNorm2d(torch.nn.Module):
         """Simple BatchNorm2d module with NCHW input."""

-        def __init__(self, num_features: int):
+        def __init__(self, num_features: int, affine: bool = True):
             super().__init__()
-            self.bn = torch.nn.BatchNorm2d(num_features)
+            self.bn = torch.nn.BatchNorm2d(num_features, affine=affine)
             # Run a forward pass to update the BN running stats.
             self.forward(torch.randn(2, num_features, 4, 4) * 2 + 2)

@@ -57,9 +59,10 @@ def forward(self, x):

     def test_fp32_batch_norm_nc(self):
         """Test that BatchNorm1d with NC input is decomposed to convolution."""
-        (
+        model = self.BatchNorm1dNC(3).eval()
+        tester = (
             Tester(
-                self.BatchNorm1dNC(3).eval(),
+                model,
                 (torch.randn(2, 3),),
             )
             .export()
@@ -70,12 +73,14 @@ def test_fp32_batch_norm_nc(self):
             .check_not([self.bn_name])
             .run_method_and_compare_outputs()
         )
+        self._validate_decomposition(tester.get_artifact(), torch.float32, 3, 1)

     def test_fp32_batch_norm_ncl(self):
         """Test that BatchNorm1d with NCL input is decomposed to convolution."""
-        (
+        model = self.BatchNorm1dNCL(3).eval()
+        tester = (
             Tester(
-                self.BatchNorm1dNCL(3).eval(),
+                model,
                 (torch.randn(2, 3, 4),),
             )
             .export()
@@ -86,12 +91,50 @@ def test_fp32_batch_norm_ncl(self):
             .check_not([self.bn_name])
             .run_method_and_compare_outputs()
         )
+        self._validate_decomposition(tester.get_artifact(), torch.float32, 3, 1)

     def test_fp32_batch_norm_nchw(self):
         """Test that BatchNorm2d with NCHW input is decomposed to convolution."""
-        (
+        model = self.BatchNorm2d(3).eval()
+        tester = (
+            Tester(
+                model,
+                (torch.randn(2, 3, 4, 4),),
+            )
+            .export()
+            .to_edge()
+            .check_count({self.bn_name: 1})
+            .run_passes(self.PassStage)
+            .check_count({self.conv_name: 1})
+            .check_not([self.bn_name])
+            .run_method_and_compare_outputs()
+        )
+        self._validate_decomposition(tester.get_artifact(), torch.float32, 3, 2)
+
+    def test_fp16_batch_norm_nchw(self):
+        """Test that fp16 BatchNorm2d with NCHW input is decomposed to convolution."""
+        model = self.BatchNorm2d(3).to(torch.float16).eval()
+        tester = (
             Tester(
-                self.BatchNorm2d(3).eval(),
+                model,
+                (torch.randn(2, 3, 4, 4, dtype=torch.float16),),
+            )
+            .export()
+            .to_edge()
+            .check_count({self.bn_name: 1})
+            .run_passes(self.PassStage)
+            .check_count({self.conv_name: 1})
+            .check_not([self.bn_name])
+            .run_method_and_compare_outputs()
+        )
+        self._validate_decomposition(tester.get_artifact(), torch.float16, 3, 2)
+
+    def test_fp32_batch_norm_nchw_non_affine(self):
+        """Test that non-affine BatchNorm2d with NCHW input is decomposed to convolution."""
+        model = self.BatchNorm2d(3, affine=False).eval()
+        tester = (
+            Tester(
+                model,
                 (torch.randn(2, 3, 4, 4),),
             )
             .export()
@@ -102,3 +145,36 @@ def test_fp32_batch_norm_nchw(self):
             .check_not([self.bn_name])
             .run_method_and_compare_outputs()
         )
+        self._validate_decomposition(tester.get_artifact(), torch.float32, 3, 2)
+
+    def _validate_decomposition(self, edge_manager: EdgeProgramManager, dtype: torch.dtype, num_channels: int, spatial_dims: int):
+        # Verify that the graph contains a 1x1 depthwise convolution and that
+        # the transformed parameter dtypes match the original.
+
+        conv_node = next(
+            n
+            for n in edge_manager.exported_program().graph.nodes
+            if n.target == exir_ops.edge.aten.convolution.default
+        )
+        self.assertEqual(conv_node.meta["val"].dtype, dtype)
+
+        self.assertEqual(len(conv_node.args), 9)
+        _, w_node, b_node, stride, padding, dilation, transposed, output_padding, groups = conv_node.args
+
+        # Check the convolution parameters. It should be a 1x1 depthwise convolution.
+        self.assertEqual(stride, [1] * spatial_dims)
+        self.assertEqual(padding, [0] * spatial_dims)
+        self.assertEqual(dilation, [1] * spatial_dims)
+        self.assertEqual(transposed, False)
+        self.assertEqual(output_padding, [0] * spatial_dims)
+        self.assertEqual(groups, num_channels)
+
+        w_meta = w_node.meta["val"]
+        b_meta = b_node.meta["val"]
+
+        # Weight should be (out_c, in_c/g, kH, [kW])
+        # Bias should be (out_c,)
+        self.assertEqual(w_meta.shape, tuple([num_channels, 1] + [1] * spatial_dims))
+        self.assertEqual(w_meta.dtype, dtype)
+        self.assertEqual(b_meta.shape, (num_channels,))
+        self.assertEqual(b_meta.dtype, dtype)
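
_validate_decomposition expects the pass to have rewritten the batch norm as a per-channel 1x1 depthwise convolution. A standalone sketch (not from the commit) showing that this form matches eval-mode BatchNorm2d numerically:

    import torch
    import torch.nn.functional as F

    C, eps = 3, 1e-5
    bn = torch.nn.BatchNorm2d(C, eps=eps).eval()
    with torch.no_grad():
        bn.weight.copy_(torch.randn(C))
        bn.bias.copy_(torch.randn(C))
        bn.running_mean.copy_(torch.randn(C))
        bn.running_var.copy_(torch.rand(C) + 0.5)

    w = bn.weight / torch.sqrt(bn.running_var + eps)  # per-channel scale
    b = bn.bias - bn.running_mean * w                 # per-channel shift
    conv_w = w.reshape(C, 1, 1, 1)                    # (out_c, in_c/groups, kH, kW)

    x = torch.randn(2, C, 8, 8)
    y_conv = F.conv2d(x, conv_w, b, stride=1, padding=0, groups=C)
    print(torch.allclose(bn(x), y_conv, atol=1e-6))  # expected: True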

backends/xnnpack/test/tester/tester.py

Lines changed: 2 additions & 0 deletions
@@ -107,6 +107,7 @@ def __init__(
         module: torch.nn.Module,
         example_inputs: Tuple[torch.Tensor],
         dynamic_shapes: Optional[Tuple[Any]] = None,
+        **kwargs,
     ):
         # Specialize for XNNPACK
         stage_classes = (
@@ -127,4 +128,5 @@ def __init__(
             stage_classes=stage_classes,
             example_inputs=example_inputs,
             dynamic_shapes=dynamic_shapes,
+            **kwargs,
         )
