Skip to content

Commit da7b76d

Browse files
angelayifacebook-github-bot
authored andcommitted
Update verifier
Summary: Fixes #7998 Differential Revision: D68839524
1 parent c5fea7e commit da7b76d

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

exir/program/test/test_program.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,45 @@ def forward(self, x, y):
313313
)
314314
edge_manager.to_executorch()
315315

316+
def test_data_dependent(self):
317+
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
318+
torch.library.define(
319+
"mylib::foo1",
320+
"(Tensor a, Tensor b) -> Tensor",
321+
tags=torch.Tag.pt2_compliant_tag,
322+
lib=lib,
323+
)
324+
325+
@torch.library.impl("mylib::foo1", "cpu", lib=lib)
326+
def foo_impl(a, b):
327+
return a + b
328+
329+
@torch.library.register_fake("mylib::foo1", lib=lib)
330+
def mylib_foo_default_fake(*args, **kwargs):
331+
ctx = torch.library.get_ctx()
332+
fake_shape = ctx.new_dynamic_size()
333+
return torch.empty(fake_shape, dtype=torch.float32, device="cpu")
334+
335+
class M(torch.nn.Module):
336+
def forward(self, a, b, c):
337+
res = torch.ops.mylib.foo1(a, b)
338+
339+
c_item = c.item()
340+
torch._check_is_size(c_item)
341+
torch._check(c_item < res.shape[0])
342+
return res[:c_item]
343+
344+
inp = (torch.randn(10), torch.randn(10), torch.tensor(3))
345+
346+
ep = export(M(), inp)
347+
edge = to_edge(ep)
348+
self.assertTrue(torch.allclose(
349+
edge.exported_program().module()(
350+
*inp
351+
),
352+
M()(*inp),
353+
))
354+
316355
def test_edge_manager_transform(self):
317356
edge_manager: EdgeProgramManager = to_edge(
318357
get_exported_programs(), get_config_methods()

exir/verification/verifier.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
EdgeOpArgValidator,
2121
RunHigherOrderOperatorError,
2222
)
23+
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
24+
from executorch.exir.passes.replace_aten_with_edge_pass import DISALLOW_LIST
2325

2426
from torch._dispatch.python import enable_python_dispatcher
2527
from torch._export.utils import _detect_fake_mode_from_gm
@@ -107,8 +109,7 @@ def _get_exception_list(self) -> List[torch._ops.OpOverload]:
107109
torch.ops.aten.max.default, # TODO(T188268054)
108110
torch.ops.aten.min.default, # TODO(T188268054)
109111
torch.ops.aten.full_like.default, # TODO(T183507359)
110-
]
111-
exception_list += self._exception_list
112+
] + list(_EXECUTORCH_SYM_OPS) + DISALLOW_LIST+ self._exception_list
112113

113114
return exception_list
114115

@@ -251,11 +252,12 @@ def check_valid_edge_op(self, op):
251252
op
252253
in [
253254
operator.getitem,
254-
torch.ops.aten.sym_size.int,
255255
torch.ops.aten.scalar_tensor.default,
256256
torch.ops.aten._assert_async.msg,
257257
torch.ops.aten._assert_scalar.default,
258258
]
259+
+ DISALLOW_LIST
260+
+ list(_EXECUTORCH_SYM_OPS)
259261
+ self._exception_list
260262
):
261263
return

0 commit comments

Comments
 (0)