
Commit b4738be

lucylq authored and facebook-github-bot committed
Dedup delegate blobs in emitter (pytorch#14564)
Summary:
Previously we deduplicated entire 'BackendDelegate' entries using the preprocessed blob. If two BackendDelegate fields had different ids or compile specs but the same preprocessed blob, we would take the first one and use it in the execution plan, and the id/compile specs of the second would be lost.

This diff:
1. Deduplicates only the preprocessed blob; each BackendDelegate retains its own id, compile specs, etc.
2. Removes the per-method 'delegate_cache', since we already have a program-wide delegate cache.
3. Adds a test confirming there is one delegate segment but two BackendDelegate references pointing at it.

Reviewed By: JacobSzwejbka

Differential Revision: D83162107
1 parent 684b5fd · commit b4738be

3 files changed (+77 -25 lines)
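
Before the per-file diffs, here is a minimal sketch of the behavior the summary describes, assuming simplified stand-in types (Blob, DelegateEntry, ProgramState, and emit_delegate are hypothetical names, not the real ExecuTorch schema or emitter classes): the preprocessed payload is deduplicated program-wide by content hash, while every lowered module still gets its own BackendDelegate-style entry carrying its id and compile specs.

import hashlib
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Blob:
    # Stand-in for BackendDelegateInlineData: the raw preprocessed bytes.
    data: bytes


@dataclass
class DelegateEntry:
    # Stand-in for BackendDelegate: keeps its own backend id / compile specs
    # and points at a shared blob by index.
    backend_id: str
    compile_specs: List[bytes]
    blob_index: int


@dataclass
class ProgramState:
    # Program-wide state: blobs are deduplicated here by content hash.
    blob_cache: Dict[str, int] = field(default_factory=dict)  # sha256 -> blob index
    blobs: List[Blob] = field(default_factory=list)
    delegates: List[DelegateEntry] = field(default_factory=list)


def emit_delegate(
    state: ProgramState, backend_id: str, compile_specs: List[bytes], processed: bytes
) -> int:
    # Dedup only the preprocessed blob; always record a fresh delegate entry
    # so differing ids/compile specs are never dropped.
    hashed = hashlib.sha256(processed).hexdigest()
    blob_index = state.blob_cache.get(hashed)
    if blob_index is None:
        blob_index = len(state.blobs)
        state.blob_cache[hashed] = blob_index
        state.blobs.append(Blob(processed))
    state.delegates.append(DelegateEntry(backend_id, compile_specs, blob_index))
    return len(state.delegates) - 1  # index the DelegateCall instruction would use


# Two lowered modules with identical payloads but different compile specs:
state = ProgramState()
emit_delegate(state, "DemoBackend", [b"fast"], b"same-payload")
emit_delegate(state, "DemoBackend", [b"small"], b"same-payload")
assert len(state.blobs) == 1          # one shared blob
assert len(state.delegates) == 2      # two delegate entries
assert state.delegates[0].blob_index == state.delegates[1].blob_index == 0

The new test at the bottom of this commit asserts the same shape against the real emitter: two plan.delegates entries, both with processed.index == 0, and a single backend_delegate_data blob.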

exir/emit/_emit_program.py

Lines changed: 0 additions & 1 deletion
@@ -164,7 +164,6 @@ def emit_program(
             operators=[],
             delegates=[],
             operator_cache={},
-            delegate_cache={},
             emit_stacktrace=emit_stacktrace,
             emit_mutable_buffer_names=emit_mutable_buffer_names,
         )

exir/emit/_emitter.py

Lines changed: 19 additions & 24 deletions
@@ -147,8 +147,6 @@ class _EmitterState:
     operators: List[Operator]
     delegates: List[BackendDelegate]
     operator_cache: Dict[Tuple[str, str], int]
-    # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
-    delegate_cache: Dict[str, int]
     emit_stacktrace: bool
     emit_mutable_buffer_names: bool

@@ -1092,7 +1090,7 @@ def _emit_delegate(
         delegate's blob."""
         processed_bytes = lowered_module.processed_bytes
         hashed = hashlib.sha256(processed_bytes).hexdigest()
-        delegate_index = self.emitter_state.delegate_cache.get(hashed)
+        delegate_index = self.program_state.backend_delegate_data_cache.get(hashed)
         delegate_ret = None

         if isinstance(self.node.meta["spec"], list):

@@ -1130,28 +1128,20 @@ def _emit_delegate(
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            hashed = hashlib.sha256(processed_bytes).hexdigest()
-            data_index: Optional[int] = (
-                self.program_state.backend_delegate_data_cache.get(hashed)
+            delegate_index = len(self.program_state.backend_delegate_data_cache)
+            self.program_state.backend_delegate_data_cache[hashed] = delegate_index
+            self.program_state.backend_delegate_data.append(
+                BackendDelegateInlineData(data=processed_bytes)
             )
-            if data_index is None:
-                data_index = len(self.program_state.backend_delegate_data)
-                self.program_state.backend_delegate_data_cache[hashed] = data_index
-                self.program_state.backend_delegate_data.append(
-                    BackendDelegateInlineData(data=processed_bytes)
-                )
-
-            backend_delegate = BackendDelegate(
-                id=lowered_module.backend_id,
-                processed=BackendDelegateDataReference(
-                    location=DataLocation.INLINE, index=data_index
-                ),
-                compile_specs=lowered_module.compile_specs,
-            )
-            delegate_index = len(self.emitter_state.delegate_cache)
-            self.emitter_state.delegates.append(backend_delegate)
-            self.emitter_state.delegate_cache[hashed] = delegate_index

+        backend_delegate = BackendDelegate(
+            id=lowered_module.backend_id,
+            processed=BackendDelegateDataReference(
+                location=DataLocation.INLINE, index=delegate_index
+            ),
+            compile_specs=lowered_module.compile_specs,
+        )
+        self.emitter_state.delegates.append(backend_delegate)
         # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
         # function's spec and with default arguments. This requires us to store the function's spec
         # in to_backend()

@@ -1164,7 +1154,12 @@ def _emit_delegate(
             delegate_args.append(elem.id)

         self.chain.instructions.append(
-            Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
+            Instruction(
+                DelegateCall(
+                    delegate_index=len(self.emitter_state.delegates) - 1,
+                    args=delegate_args,
+                )
+            )
         )

         return delegate_ret

exir/emit/test/test_emit.py

Lines changed: 58 additions & 0 deletions
@@ -1770,6 +1770,64 @@ def forward(self, x):
             len(edge_program_manager.executorch_program.backend_delegate_data), 1
         )

+    def test_delegate_deduplicate_with_different_compile_specs(self) -> None:
+        class LowerableSubModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        lowered = LowerableSubModel()
+        example_input = (torch.ones(1),)
+
+        lowered_edge = to_edge(export(lowered, example_input))
+
+        from executorch.exir.backend.compile_spec_schema import CompileSpec
+        # noqa: F401
+        from executorch.exir.backend.test.backend_with_compiler_demo import (
+            BackendWithCompilerDemo,  # noqa: F401
+        )
+
+        compile_specs1 = [CompileSpec("config", b"fast")]
+        compile_specs2 = [CompileSpec("config", b"small")]
+        lowered_module1 = to_backend(
+            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs1
+        )
+        lowered_module2 = to_backend(
+            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs2
+        )
+
+        class CompositeModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lowerable1 = lowered_module1
+                self.lowerable2 = lowered_module2
+
+            def forward(self, x):
+                a = self.lowerable1(x)
+                b = self.lowerable2(a)
+                return a, b
+
+        composite_model = CompositeModel()
+        model_inputs = (torch.ones(1),)
+        edge_prog = to_edge(export(composite_model, model_inputs)).to_executorch()
+
+        exported_program = edge_prog.exported_program()
+        program = emit_program({"method1": exported_program}, False).program
+        self.assertEqual(len(program.execution_plan), 1)
+
+        plan = program.execution_plan[0]
+        # Two delegates that point to the same blob.
+        self.assertEqual(len(plan.delegates), 2)
+        self.assertEqual(plan.delegates[0].processed.index, 0)
+        self.assertEqual(plan.delegates[1].processed.index, 0)
+        # Compile specs are different.
+        self.assertEqual(plan.delegates[0].compile_specs, compile_specs1)
+        self.assertEqual(plan.delegates[1].compile_specs, compile_specs2)
+        # Only one delegate blob in the backend_delegate_data.
+        self.assertEqual(len(program.backend_delegate_data), 1)
+
     def test_constant_tagged_mutable_tensors(self) -> None:
         class Net(nn.Module):
             def __init__(self):
