Skip to content

Commit fea4292

Browse files
angelayifacebook-github-bot
authored andcommitted
Update verifier (#8034)
Summary: Fixes #7998 Reviewed By: JacobSzwejbka Differential Revision: D68839524
1 parent c5fea7e commit fea4292

File tree

2 files changed

+57
-11
lines changed

2 files changed

+57
-11
lines changed

exir/program/test/test_program.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,45 @@ def forward(self, x, y):
313313
)
314314
edge_manager.to_executorch()
315315

316+
def test_data_dependent(self):
317+
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
318+
torch.library.define(
319+
"mylib::foo1",
320+
"(Tensor a, Tensor b) -> Tensor",
321+
tags=torch.Tag.pt2_compliant_tag,
322+
lib=lib,
323+
)
324+
325+
@torch.library.impl("mylib::foo1", "cpu", lib=lib)
326+
def foo_impl(a, b):
327+
return a + b
328+
329+
@torch.library.register_fake("mylib::foo1", lib=lib)
330+
def mylib_foo_default_fake(*args, **kwargs):
331+
ctx = torch.library.get_ctx()
332+
fake_shape = ctx.new_dynamic_size()
333+
return torch.empty(fake_shape, dtype=torch.float32, device="cpu")
334+
335+
class M(torch.nn.Module):
336+
def forward(self, a, b, c):
337+
res = torch.ops.mylib.foo1(a, b)
338+
339+
c_item = c.item()
340+
torch._check_is_size(c_item)
341+
torch._check(c_item < res.shape[0])
342+
return res[:c_item]
343+
344+
inp = (torch.randn(10), torch.randn(10), torch.tensor(3))
345+
346+
ep = export(M(), inp)
347+
edge = to_edge(ep)
348+
self.assertTrue(
349+
torch.allclose(
350+
edge.exported_program().module()(*inp),
351+
M()(*inp),
352+
)
353+
)
354+
316355
def test_edge_manager_transform(self):
317356
edge_manager: EdgeProgramManager = to_edge(
318357
get_exported_programs(), get_config_methods()

exir/verification/verifier.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from executorch.exir.error import ExportError, ExportErrorType
1717
from executorch.exir.lowered_backend_module import LoweredBackendModule
1818
from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
19+
from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
20+
from executorch.exir.passes.replace_aten_with_edge_pass import DISALLOW_LIST
1921
from executorch.exir.verification.arg_validator import (
2022
EdgeOpArgValidator,
2123
RunHigherOrderOperatorError,
@@ -99,16 +101,20 @@ def __init__(self) -> None:
99101
self._exception_list = exception_list if exception_list else []
100102

101103
def _get_exception_list(self) -> List[torch._ops.OpOverload]:
102-
exception_list = [
103-
torch.ops.aten.mkldnn_rnn_layer.default,
104-
torch.ops.aten._upsample_bilinear2d_aa.default,
105-
torch.ops.aten.quantize_per_tensor.default,
106-
torch.ops.aten.dequantize.self,
107-
torch.ops.aten.max.default, # TODO(T188268054)
108-
torch.ops.aten.min.default, # TODO(T188268054)
109-
torch.ops.aten.full_like.default, # TODO(T183507359)
110-
]
111-
exception_list += self._exception_list
104+
exception_list = (
105+
[
106+
torch.ops.aten.mkldnn_rnn_layer.default,
107+
torch.ops.aten._upsample_bilinear2d_aa.default,
108+
torch.ops.aten.quantize_per_tensor.default,
109+
torch.ops.aten.dequantize.self,
110+
torch.ops.aten.max.default, # TODO(T188268054)
111+
torch.ops.aten.min.default, # TODO(T188268054)
112+
torch.ops.aten.full_like.default, # TODO(T183507359)
113+
]
114+
+ list(_EXECUTORCH_SYM_OPS)
115+
+ DISALLOW_LIST
116+
+ self._exception_list
117+
)
112118

113119
return exception_list
114120

@@ -251,11 +257,12 @@ def check_valid_edge_op(self, op):
251257
op
252258
in [
253259
operator.getitem,
254-
torch.ops.aten.sym_size.int,
255260
torch.ops.aten.scalar_tensor.default,
256261
torch.ops.aten._assert_async.msg,
257262
torch.ops.aten._assert_scalar.default,
258263
]
264+
+ DISALLOW_LIST
265+
+ list(_EXECUTORCH_SYM_OPS)
259266
+ self._exception_list
260267
):
261268
return

0 commit comments

Comments
 (0)