Skip to content

Commit 8e2ed67

Browse files
tarun292 authored and facebook-github-bot committed
Fix emitter bug which doesn't allow input lists to delegates (#282)
Summary: Pull Request resolved: #282 As the delegate op doesn't have a schema we pass in `None` here when emitting the delegate arguments. https://www.internalfb.com/code/fbsource/[4c61670ae0bd654ca630c2288f685d3a4dfa8ece]/fbcode/executorch/exir/emit/_emitter.py?lines=954 `emit_list` expects a `val_type` to be passed in though and triggers an assert when `None` is passed in. https://www.internalfb.com/code/fbsource/[4c61670ae0bd654ca630c2288f685d3a4dfa8ece]/fbcode/executorch/exir/emit/_emitter.py?lines=471 To handle this we try to infer the `JitType` for lists using the newly added function `_get_list_jit_type`. Reviewed By: cccclai Differential Revision: D49180438 fbshipit-source-id: 44d6ab1fdaecdd74ec784d2b3e8e7b15fb63c9d4
1 parent 2f21fe6 commit 8e2ed67

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

exir/emit/_emitter.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,30 @@ def _get_buffer_idx(spec: TensorSpec, program_state: _ProgramState) -> int:
448448
# For constant tensors, allocation_info = None.
449449
return EValue(make_tensor_value(buffer_idx, None, spec))
450450

451+
def _get_list_jit_type(self, val: List[_Argument]) -> _SchemaType:
    """Infers the JIT element type for a python list of arguments.

    Used when no operator schema is available (e.g. when emitting the
    arguments of a delegate call), so the element type must be deduced
    from the runtime values themselves.

    Args:
        val: The list whose element type should be inferred. Assumed to be
            homogeneous; only the first element is inspected for scalars.

    Returns:
        The torch JitType shared by the elements of ``val``.

    Raises:
        InternalError: If the elements are not Tensors, bools, ints, or
            floats.
    """
    assert isinstance(
        val, list
    ), f"Input to _get_list_jit_type was expected to be an instance of list but received {type(val)}"

    # Tensor arguments arrive wrapped as _AbstractValue with a concrete
    # `tensor` attribute attached. NOTE: all() is vacuously True for an
    # empty list, so [] infers as List[Tensor] — presumably acceptable
    # for callers; TODO confirm.
    is_tensor_type = all(
        isinstance(v, _AbstractValue) and v.tensor is not None for v in val
    )
    if is_tensor_type:
        return torch.TensorType.get()
    # bool must be checked BEFORE int: bool is a subclass of int in
    # Python (isinstance(True, int) is True), so testing int first would
    # misclassify List[bool] as List[int] and make this branch dead code.
    elif isinstance(val[0], bool):
        return torch.BoolType.get()
    elif isinstance(val[0], int):
        return torch.IntType.get()
    elif isinstance(val[0], float):
        return torch.FloatType.get()

    raise InternalError(
        self._emit_node_specific_error(
            self.node,
            "Couldn't determine JitType for list of elements. Only supports int, float, bool, and Tensor.",
        )
    )
474+
451475
def _constant_to_evalue( # noqa: C901
452476
self,
453477
val: _Argument,
@@ -465,6 +489,8 @@ def _constant_to_evalue( # noqa: C901
465489
if isinstance(val, list):
466490
# Refine Optional[List[T]] -> List[T] This works because if the val was None it would
467491
# have converted to Null before this function call.
492+
if val_type is None:
493+
val_type = torch.ListType(self._get_list_jit_type(val)) # pyre-ignore
468494
if type(val_type) == torch.OptionalType:
469495
val_type = val_type.getElementType()
470496
assert type(val_type) == torch.ListType

exir/emit/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ python_unittest(
1313
"//executorch/exir:lib",
1414
"//executorch/exir:print_program",
1515
"//executorch/exir:schema",
16+
"//executorch/exir/backend:backend_api",
1617
"//executorch/exir/emit:lib",
1718
"//executorch/exir/passes:const_prop_pass",
1819
"//executorch/exir/tests:lib",

exir/emit/test/test_emit.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import executorch.exir.tests.models as models
1717
import torch
1818
from executorch.exir import CaptureConfig, EdgeCompileConfig, ExecutorchProgram
19+
from executorch.exir.backend.backend_api import to_backend
20+
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
1921
from executorch.exir.emit import emit_program # noqa
2022
from executorch.exir.error import InternalError
2123
from executorch.exir.passes.const_prop_pass import ConstPropPass
@@ -42,6 +44,7 @@
4244
_load_for_executorch_from_buffer,
4345
)
4446
from functorch.experimental import control_flow
47+
from torch import nn
4548

4649

4750
class TestEmit(unittest.TestCase):
@@ -1197,3 +1200,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
11971200
idx,
11981201
node.meta.get("debug_handle"),
11991202
)
1203+
1204+
def test_delegate_with_input_list(self) -> None:
    """Regression test: emitting a delegate call whose input is a list of
    tensors must succeed (the delegate op has no schema, so the emitter
    has to infer the list element type)."""

    class BackendWithCompilerDemo(BackendDetails):
        # Minimal backend stub; the payload contents are irrelevant here.
        @staticmethod
        def preprocess(
            edge_program,
            compile_specs,
        ) -> bytes:
            return PreprocessResult(
                processed_bytes="test".encode("utf8"),
                debug_handle_map=None,
            )

    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # Consumes a python list of tensors.
            return torch.cat(x)

    example_inputs = ([torch.ones(2, 2), torch.ones(2, 2)],)
    edge_module = exir.capture(
        TestModel(), example_inputs, exir.CaptureConfig()
    ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
    lowered = to_backend(
        "BackendWithCompilerDemo", edge_module.exported_program, None
    )

    class CompositeModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lowered_module = lowered

        def forward(self, list_a):
            return self.lowered_module(list_a)

    # Emitting and serializing the composite program exercises the
    # list-input-to-delegate path end to end.
    program = (
        exir.capture(CompositeModule(), example_inputs, exir.CaptureConfig())
        .to_edge()
        .to_executorch()
    )
    program.buffer

0 commit comments

Comments
 (0)