
Commit 00fe3ac

Authored by pytorchbot, lucylq, and GregoryComer

Dedup delegate blobs in emitter (#14658)

Summary: Previously we deduplicated entire BackendDelegate entries keyed on the preprocessed blob, so if two BackendDelegate entries had different ids or compile specs, those differences were disregarded. This diff:
1. Deduplicates only the preprocessed blob; each BackendDelegate retains its own id, compile specs, etc.
2. Removes the per-method delegate_cache, since there is already a program-wide delegate cache.
3. Adds a test confirming that a single delegate segment is emitted with two BackendDelegate references pointing at it.

Differential Revision: D83162107

Co-authored-by: lucylq <[email protected]>
Co-authored-by: Gregory Comer <[email protected]>

1 parent c18d6e9 commit 00fe3ac
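
For readers skimming the diff below, here is a minimal, self-contained sketch of the dedup pattern this change implements: the blob cache is program-wide and keyed by the SHA-256 of the preprocessed bytes, while each delegation site still appends its own BackendDelegate carrying its id and compile specs. The dataclasses and the emit_delegate helper are simplified stand-ins for illustration, not the real ExecuTorch schema or emitter API.

    import hashlib
    from dataclasses import dataclass, field
    from typing import Dict, List

    # Simplified stand-ins; the real definitions live in executorch.exir.schema.
    @dataclass
    class BackendDelegateInlineData:
        data: bytes

    @dataclass
    class BackendDelegate:
        id: str
        data_index: int  # reference into the program-wide blob list
        compile_specs: List[bytes]

    @dataclass
    class ProgramState:
        # hash(preprocessed blob) -> index into backend_delegate_data
        backend_delegate_data_cache: Dict[str, int] = field(default_factory=dict)
        backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
        delegates: List[BackendDelegate] = field(default_factory=list)

    def emit_delegate(state: ProgramState, backend_id: str,
                      processed_bytes: bytes, compile_specs: List[bytes]) -> int:
        # Dedup only the blob: look it up by content hash.
        hashed = hashlib.sha256(processed_bytes).hexdigest()
        data_index = state.backend_delegate_data_cache.get(hashed)
        if data_index is None:
            data_index = len(state.backend_delegate_data)
            state.backend_delegate_data_cache[hashed] = data_index
            state.backend_delegate_data.append(BackendDelegateInlineData(processed_bytes))
        # Always append a fresh BackendDelegate so ids/compile specs are preserved.
        state.delegates.append(BackendDelegate(backend_id, data_index, compile_specs))
        return len(state.delegates) - 1  # index used by the DelegateCall

    state = ProgramState()
    emit_delegate(state, "demo", b"blob", [b"fast"])
    emit_delegate(state, "demo", b"blob", [b"small"])
    assert len(state.delegates) == 2               # two references...
    assert len(state.backend_delegate_data) == 1   # ...sharing one blob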

File tree

3 files changed: +73 −25 lines

exir/emit/_emit_program.py
Lines changed: 0 additions & 1 deletion

@@ -164,7 +164,6 @@ def emit_program(
         operators=[],
         delegates=[],
         operator_cache={},
-        delegate_cache={},
         emit_stacktrace=emit_stacktrace,
         emit_mutable_buffer_names=emit_mutable_buffer_names,
     )

exir/emit/_emitter.py
Lines changed: 19 additions & 24 deletions

@@ -146,8 +146,6 @@ class _EmitterState:
     operators: List[Operator]
     delegates: List[BackendDelegate]
     operator_cache: Dict[Tuple[str, str], int]
-    # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
-    delegate_cache: Dict[str, int]
     emit_stacktrace: bool
     emit_mutable_buffer_names: bool

@@ -1091,7 +1089,7 @@ def _emit_delegate(
         delegate's blob."""
         processed_bytes = lowered_module.processed_bytes
         hashed = hashlib.sha256(processed_bytes).hexdigest()
-        delegate_index = self.emitter_state.delegate_cache.get(hashed)
+        delegate_index = self.program_state.backend_delegate_data_cache.get(hashed)
         delegate_ret = None

         if isinstance(self.node.meta["spec"], list):

@@ -1129,28 +1127,20 @@ def _emit_delegate(
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            hashed = hashlib.sha256(processed_bytes).hexdigest()
-            data_index: Optional[int] = (
-                self.program_state.backend_delegate_data_cache.get(hashed)
+            delegate_index = len(self.program_state.backend_delegate_data_cache)
+            self.program_state.backend_delegate_data_cache[hashed] = delegate_index
+            self.program_state.backend_delegate_data.append(
+                BackendDelegateInlineData(data=processed_bytes)
             )
-            if data_index is None:
-                data_index = len(self.program_state.backend_delegate_data)
-                self.program_state.backend_delegate_data_cache[hashed] = data_index
-                self.program_state.backend_delegate_data.append(
-                    BackendDelegateInlineData(data=processed_bytes)
-                )
-
-            backend_delegate = BackendDelegate(
-                id=lowered_module.backend_id,
-                processed=BackendDelegateDataReference(
-                    location=DataLocation.INLINE, index=data_index
-                ),
-                compile_specs=lowered_module.compile_specs,
-            )
-            delegate_index = len(self.emitter_state.delegate_cache)
-            self.emitter_state.delegates.append(backend_delegate)
-            self.emitter_state.delegate_cache[hashed] = delegate_index

+        backend_delegate = BackendDelegate(
+            id=lowered_module.backend_id,
+            processed=BackendDelegateDataReference(
+                location=DataLocation.INLINE, index=delegate_index
+            ),
+            compile_specs=lowered_module.compile_specs,
+        )
+        self.emitter_state.delegates.append(backend_delegate)
         # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
         # function's spec and with default arguments. This requires us to store the function's spec
         # in to_backend()

@@ -1163,7 +1153,12 @@ def _emit_delegate(
             delegate_args.append(elem.id)

         self.chain.instructions.append(
-            Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
+            Instruction(
+                DelegateCall(
+                    delegate_index=len(self.emitter_state.delegates) - 1,
+                    args=delegate_args,
+                )
+            )
         )

         return delegate_ret
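
One subtlety in the last hunk: a BackendDelegate is now appended for every delegation site, even on a cache hit, so delegate_index identifies the shared blob rather than the delegate entry, and the DelegateCall must reference the entry just appended. A tiny sketch of why, using hypothetical names rather than the real emitter state:

    # Hypothetical stand-ins: `delegates` models plan.delegates, `blob_cache`
    # models program_state.backend_delegate_data_cache.
    delegates = []
    blob_cache = {}

    def emit(blob_hash: str, tag: str) -> int:
        # setdefault computes len(blob_cache) before inserting, so the first
        # occurrence of a hash claims the next free blob slot.
        data_index = blob_cache.setdefault(blob_hash, len(blob_cache))
        delegates.append((tag, data_index))
        return len(delegates) - 1  # what the DelegateCall must point at

    assert emit("abc", "fast") == 0   # new blob, delegate 0
    assert emit("abc", "small") == 1  # cache hit: same blob, but delegate 1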

exir/emit/test/test_emit.py
Lines changed: 54 additions & 0 deletions

@@ -1770,6 +1770,60 @@ def forward(self, x):
             len(edge_program_manager.executorch_program.backend_delegate_data), 1
         )

+    def test_delegate_deduplicate_with_different_compile_specs(self) -> None:
+        class LowerableSubModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        lowered = LowerableSubModel()
+        example_input = (torch.ones(1),)
+
+        lowered_edge = to_edge(export(lowered, example_input))
+
+        from executorch.exir.backend.compile_spec_schema import CompileSpec
+
+        compile_specs1 = [CompileSpec("config", b"fast")]
+        compile_specs2 = [CompileSpec("config", b"small")]
+        lowered_module1 = to_backend(
+            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs1
+        )
+        lowered_module2 = to_backend(
+            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs2
+        )
+
+        class CompositeModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lowerable1 = lowered_module1
+                self.lowerable2 = lowered_module2
+
+            def forward(self, x):
+                a = self.lowerable1(x)
+                b = self.lowerable2(a)
+                return a, b
+
+        composite_model = CompositeModel()
+        model_inputs = (torch.ones(1),)
+        edge_prog = to_edge(export(composite_model, model_inputs)).to_executorch()
+
+        exported_program = edge_prog.exported_program()
+        program = emit_program({"method1": exported_program}, False).program
+        self.assertEqual(len(program.execution_plan), 1)
+
+        plan = program.execution_plan[0]
+        # Two delegates that point to the same blob.
+        self.assertEqual(len(plan.delegates), 2)
+        self.assertEqual(plan.delegates[0].processed.index, 0)
+        self.assertEqual(plan.delegates[1].processed.index, 0)
+        # Compile specs are different.
+        self.assertEqual(plan.delegates[0].compile_specs, compile_specs1)
+        self.assertEqual(plan.delegates[1].compile_specs, compile_specs2)
+        # Only one delegate blob in the backend_delegate_data.
+        self.assertEqual(len(program.backend_delegate_data), 1)
+
     def test_constant_tagged_mutable_tensors(self) -> None:
         class Net(nn.Module):
             def __init__(self):