Skip to content

Commit fd5f946

Browse files
authored
Dedup delegate blobs in emitter
Differential Revision: D83162107
Pull Request resolved: #14564
1 parent 8e6e320 commit fd5f946

File tree

3 files changed

+73
-25
lines changed

3 files changed

+73
-25
lines changed

exir/emit/_emit_program.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ def emit_program(
164164
operators=[],
165165
delegates=[],
166166
operator_cache={},
167-
delegate_cache={},
168167
emit_stacktrace=emit_stacktrace,
169168
emit_mutable_buffer_names=emit_mutable_buffer_names,
170169
)

exir/emit/_emitter.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,6 @@ class _EmitterState:
147147
operators: List[Operator]
148148
delegates: List[BackendDelegate]
149149
operator_cache: Dict[Tuple[str, str], int]
150-
# delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
151-
delegate_cache: Dict[str, int]
152150
emit_stacktrace: bool
153151
emit_mutable_buffer_names: bool
154152

@@ -1092,7 +1090,7 @@ def _emit_delegate(
10921090
delegate's blob."""
10931091
processed_bytes = lowered_module.processed_bytes
10941092
hashed = hashlib.sha256(processed_bytes).hexdigest()
1095-
delegate_index = self.emitter_state.delegate_cache.get(hashed)
1093+
delegate_index = self.program_state.backend_delegate_data_cache.get(hashed)
10961094
delegate_ret = None
10971095

10981096
if isinstance(self.node.meta["spec"], list):
@@ -1130,28 +1128,20 @@ def _emit_delegate(
11301128
if delegate_index is None:
11311129
# Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
11321130
# present.
1133-
hashed = hashlib.sha256(processed_bytes).hexdigest()
1134-
data_index: Optional[int] = (
1135-
self.program_state.backend_delegate_data_cache.get(hashed)
1131+
delegate_index = len(self.program_state.backend_delegate_data_cache)
1132+
self.program_state.backend_delegate_data_cache[hashed] = delegate_index
1133+
self.program_state.backend_delegate_data.append(
1134+
BackendDelegateInlineData(data=processed_bytes)
11361135
)
1137-
if data_index is None:
1138-
data_index = len(self.program_state.backend_delegate_data)
1139-
self.program_state.backend_delegate_data_cache[hashed] = data_index
1140-
self.program_state.backend_delegate_data.append(
1141-
BackendDelegateInlineData(data=processed_bytes)
1142-
)
1143-
1144-
backend_delegate = BackendDelegate(
1145-
id=lowered_module.backend_id,
1146-
processed=BackendDelegateDataReference(
1147-
location=DataLocation.INLINE, index=data_index
1148-
),
1149-
compile_specs=lowered_module.compile_specs,
1150-
)
1151-
delegate_index = len(self.emitter_state.delegate_cache)
1152-
self.emitter_state.delegates.append(backend_delegate)
1153-
self.emitter_state.delegate_cache[hashed] = delegate_index
11541136

1137+
backend_delegate = BackendDelegate(
1138+
id=lowered_module.backend_id,
1139+
processed=BackendDelegateDataReference(
1140+
location=DataLocation.INLINE, index=delegate_index
1141+
),
1142+
compile_specs=lowered_module.compile_specs,
1143+
)
1144+
self.emitter_state.delegates.append(backend_delegate)
11551145
# TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
11561146
# function's spec and with default arguments. This requires us to store the function's spec
11571147
# in to_backend()
@@ -1164,7 +1154,12 @@ def _emit_delegate(
11641154
delegate_args.append(elem.id)
11651155

11661156
self.chain.instructions.append(
1167-
Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
1157+
Instruction(
1158+
DelegateCall(
1159+
delegate_index=len(self.emitter_state.delegates) - 1,
1160+
args=delegate_args,
1161+
)
1162+
)
11681163
)
11691164

11701165
return delegate_ret

exir/emit/test/test_emit.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,6 +1770,60 @@ def forward(self, x):
17701770
len(edge_program_manager.executorch_program.backend_delegate_data), 1
17711771
)
17721772

1773+
def test_delegate_deduplicate_with_different_compile_specs(self) -> None:
1774+
class LowerableSubModel(torch.nn.Module):
1775+
def __init__(self):
1776+
super().__init__()
1777+
1778+
def forward(self, x):
1779+
return torch.sin(x)
1780+
1781+
lowered = LowerableSubModel()
1782+
example_input = (torch.ones(1),)
1783+
1784+
lowered_edge = to_edge(export(lowered, example_input))
1785+
1786+
from executorch.exir.backend.compile_spec_schema import CompileSpec
1787+
1788+
compile_specs1 = [CompileSpec("config", b"fast")]
1789+
compile_specs2 = [CompileSpec("config", b"small")]
1790+
lowered_module1 = to_backend(
1791+
"BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs1
1792+
)
1793+
lowered_module2 = to_backend(
1794+
"BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs2
1795+
)
1796+
1797+
class CompositeModel(torch.nn.Module):
1798+
def __init__(self):
1799+
super().__init__()
1800+
self.lowerable1 = lowered_module1
1801+
self.lowerable2 = lowered_module2
1802+
1803+
def forward(self, x):
1804+
a = self.lowerable1(x)
1805+
b = self.lowerable2(a)
1806+
return a, b
1807+
1808+
composite_model = CompositeModel()
1809+
model_inputs = (torch.ones(1),)
1810+
edge_prog = to_edge(export(composite_model, model_inputs)).to_executorch()
1811+
1812+
exported_program = edge_prog.exported_program()
1813+
program = emit_program({"method1": exported_program}, False).program
1814+
self.assertEqual(len(program.execution_plan), 1)
1815+
1816+
plan = program.execution_plan[0]
1817+
# Two delegates that point to the same blob.
1818+
self.assertEqual(len(plan.delegates), 2)
1819+
self.assertEqual(plan.delegates[0].processed.index, 0)
1820+
self.assertEqual(plan.delegates[1].processed.index, 0)
1821+
# Compile specs are different.
1822+
self.assertEqual(plan.delegates[0].compile_specs, compile_specs1)
1823+
self.assertEqual(plan.delegates[1].compile_specs, compile_specs2)
1824+
# Only one delegate blob in the backend_delegate_data.
1825+
self.assertEqual(len(program.backend_delegate_data), 1)
1826+
17731827
def test_constant_tagged_mutable_tensors(self) -> None:
17741828
class Net(nn.Module):
17751829
def __init__(self):

0 commit comments

Comments (0)