Skip to content

Commit 08f095e

Browse files
gmagogsfmfacebook-github-bot
authored andcommitted
executorch/exir/program/test (#7397)
Summary: Pull Request resolved: #7397 Reviewed By: avikchaudhuri, ydwu4 Differential Revision: D67383235
1 parent c337bef commit 08f095e

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

exir/program/test/test_fake_program.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ def forward(self, arg) -> torch.Tensor:
3030

3131
linear = Linear()
3232
exported_program = export(
33-
linear,
34-
args=(torch.randn(10, 10),),
33+
linear, args=(torch.randn(10, 10),), strict=True
3534
).run_decompositions()
3635
return exported_program
3736

exir/program/test/test_program.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
166166
torch.ones(1),
167167
torch.zeros(1),
168168
),
169+
strict=True,
169170
).run_decompositions()
170-
programs["foo"] = export(
171-
foo,
172-
(torch.ones(1),),
173-
).run_decompositions()
171+
programs["foo"] = export(foo, (torch.ones(1),), strict=True).run_decompositions()
174172
return programs
175173

176174

@@ -289,7 +287,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
289287
return x * 3.14
290288

291289
mul = Mul()
292-
ep = to_edge(torch.export.export(mul, (torch.ones(1),))).exported_program()
290+
ep = to_edge(
291+
torch.export.export(mul, (torch.ones(1),), strict=True)
292+
).exported_program()
293293
for node in ep.graph.nodes:
294294
self.assertNotEqual(node.op, "get_attr")
295295
self.assertEqual(
@@ -306,7 +306,7 @@ def forward(self, x, y):
306306
torch._check(z < 4)
307307
return x[z : z + y.shape[0]]
308308

309-
ep = torch.export.export(M(), (torch.randn(10), torch.tensor([3])))
309+
ep = torch.export.export(M(), (torch.randn(10), torch.tensor([3])), strict=True)
310310

311311
edge_manager = to_edge(
312312
ep, compile_config=exir.EdgeCompileConfig(_check_ir_validity=False)
@@ -350,7 +350,6 @@ def test_edge_manager_transform(self):
350350
)
351351

352352
def test_issue_3659(self):
353-
354353
class Mul(torch.nn.Module):
355354
def __init__(self):
356355
super(Mul, self).__init__()
@@ -371,7 +370,10 @@ def get_dynamic_shapes(self):
371370

372371
model = Mul()
373372
ep = torch.export.export(
374-
model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
373+
model,
374+
model.get_example_inputs(),
375+
dynamic_shapes=model.get_dynamic_shapes(),
376+
strict=True,
375377
)
376378

377379
to_edge(
@@ -549,7 +551,7 @@ def _test_edge_dialect_verifier(
549551
if not isinstance(callable, torch.nn.Module):
550552
callable = WrapperModule(callable)
551553

552-
exported_foo = export(callable, inputs)
554+
exported_foo = export(callable, inputs, strict=True)
553555
_ = to_edge(exported_foo, compile_config=edge_compile_config)
554556

555557
def test_edge_dialect_custom_op(self):
@@ -697,7 +699,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
697699
from torch._export.verifier import SpecViolationError
698700

699701
input = torch.arange(9, dtype=torch.float) - 4
700-
ep = torch.export.export(LinalgNorm(), (input,))
702+
ep = torch.export.export(LinalgNorm(), (input,), strict=True)
701703

702704
# aten::linalg_norm is not a core op, so it should error out
703705
with self.assertRaises(SpecViolationError):
@@ -744,7 +746,7 @@ def count_nodes(graph_module, target):
744746

745747
def test_to_edge_with_single_preserved_op(self):
746748
model = TestLinear()
747-
program = torch.export.export(model, model._get_random_inputs())
749+
program = torch.export.export(model, model._get_random_inputs(), strict=True)
748750

749751
ops_not_to_decompose = [
750752
torch.ops.aten.linear.default,
@@ -759,7 +761,7 @@ def test_to_edge_with_single_preserved_op(self):
759761

760762
def test_to_edge_with_partial_ops_preserved(self):
761763
model = TestLinearSDPACombined()
762-
program = torch.export.export(model, model._get_random_inputs())
764+
program = torch.export.export(model, model._get_random_inputs(), strict=True)
763765

764766
ops_not_to_decompose = [
765767
torch.ops.aten.linear.default,
@@ -774,7 +776,7 @@ def test_to_edge_with_partial_ops_preserved(self):
774776

775777
def test_to_edge_with_multiple_ops_preserved(self):
776778
model = TestLinearSDPACombined()
777-
program = torch.export.export(model, model._get_random_inputs())
779+
program = torch.export.export(model, model._get_random_inputs(), strict=True)
778780

779781
ops_not_to_decompose = [
780782
torch.ops.aten.linear.default,
@@ -791,7 +793,7 @@ def test_to_edge_with_multiple_ops_preserved(self):
791793

792794
def test_to_edge_with_preserved_ops_not_in_model(self):
793795
model = TestSDPA()
794-
program = torch.export.export(model, model._get_random_inputs())
796+
program = torch.export.export(model, model._get_random_inputs(), strict=True)
795797

796798
ops_not_to_decompose = [
797799
torch.ops.aten.linear.default,
@@ -806,7 +808,7 @@ def test_to_edge_with_preserved_ops_not_in_model(self):
806808

807809
def test_save_fails(self):
808810
model = TestLinear()
809-
program = torch.export.export(model, model._get_random_inputs())
811+
program = torch.export.export(model, model._get_random_inputs(), strict=True)
810812
edge = to_edge(program)
811813
et = edge.to_executorch()
812814
with self.assertRaises(ValueError):

0 commit comments

Comments
 (0)