Skip to content

Commit d1cd92f

Browse files
stellaraccidentUbuntu
andauthored
Support the buffer mutation protocol via the FxImporter. (#577)
Taken along with llvm/torch-mlir#3074 and the hack job here (https://gist.github.com/stellaraccident/83f91c7316ea668d59e0718e179e2cfd), this gives us a path to export hermetic training steps from PyTorch. --------- Co-authored-by: Ubuntu <kyle@kyle-mem.judsoscro3wupi0qm4bjlj5m3b.bx.internal.cloudapp.net>
1 parent 59bc67d commit d1cd92f

File tree

6 files changed

+117
-1
lines changed

6 files changed

+117
-1
lines changed

core/shark_turbine/aot/builtins/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@
3636
"export_global",
3737
"export_global_tree",
3838
"export_parameters",
39+
"export_buffers",
3940
"jittable",
4041
]

core/shark_turbine/aot/builtins/globals.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@
2929
)
3030

3131

32+
__all__ = [
33+
"export_global",
34+
"export_global_tree",
35+
"export_parameters",
36+
"export_buffers",
37+
]
38+
39+
3240
class export_global(GlobalsDef, Abstractifiable):
3341
"""Exports a single global into a CompiledModule."""
3442

@@ -164,6 +172,60 @@ def __repr__(self):
164172
return f"<export_parameters {', '.join(names)}>"
165173

166174

175+
class export_buffers(GlobalsDef, TreeAbstractifiable):
176+
"""Exports buffers from an nn.Module.
177+
178+
These are exposed to procedural programs as a dictionary of param/values.
179+
"""
180+
181+
__slots__ = [
182+
"_buffer_list",
183+
"_schema",
184+
"_tree",
185+
]
186+
187+
def __init__(
188+
self,
189+
nn_module: nn.Module,
190+
*,
191+
mutable: Optional[bool] = None,
192+
external: Optional[bool] = None,
193+
external_scope: Optional[str] = None,
194+
name_mapper: Optional[NameMapCallback] = None,
195+
uninitialized: Optional[bool] = None,
196+
attrs: Optional[GlobalAttributes] = None,
197+
):
198+
if attrs is None:
199+
attrs = GlobalAttributes(
200+
mutable=bool(mutable),
201+
external=external,
202+
external_scope=external_scope,
203+
name_mapper=name_mapper,
204+
uninitialized=uninitialized,
205+
)
206+
super().__init__(attrs)
207+
self._buffer_list = list(nn_module.named_buffers())
208+
self._tree = dict(self._buffer_list)
209+
_, self._schema = tree_flatten(self._tree)
210+
211+
def items(self):
212+
for name, value in self._buffer_list:
213+
yield (name, value)
214+
215+
def schema(self) -> TreeSpec:
216+
return self._schema
217+
218+
def abstractify_tree(self):
219+
return tree_map(abstractify_single_value, self._tree)
220+
221+
def __getitem__(self, key):
222+
return self._tree[key]
223+
224+
def __repr__(self):
225+
names = [name for name, _ in self._param_list]
226+
return f"<export_buffers {', '.join(names)}>"
227+
228+
167229
def _transform_tree_to_names(prefix: str, tree):
168230
"""Produces a topologically similar tree but where each value is a fully qualified name."""
169231
join = lambda key: f"{prefix}.{key}" if prefix else key

core/shark_turbine/aot/support/ir_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626

2727
from ...support.ir_imports import (
28+
AsmState,
2829
Attribute,
2930
BF16Type,
3031
DenseElementsAttr,
@@ -197,7 +198,7 @@ def handle_mlir_error(self, op: Operation, e: MLIRError, message: str):
197198
try:
198199
with open(dump_path, "wb") as f:
199200
op.print(
200-
f,
201+
file=f,
201202
binary=True,
202203
print_generic_op_form=True,
203204
large_elements_limit=100,

core/shark_turbine/aot/support/procedural/exported_program.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
FxImporter,
2828
FxImporterHooks,
2929
GraphNodeImporter,
30+
InputInfo,
3031
)
3132

3233
from ....support.logging import aot_logger as logger
@@ -219,6 +220,27 @@ class _Hooks(FxImporterHooks):
219220
def __init__(self, module_builder: ModuleBuilder):
220221
self.module_builder = module_builder
221222

223+
def store_produced_value(
224+
self,
225+
gni: GraphNodeImporter,
226+
py_value: Any,
227+
produced_ir_value: Any,
228+
info: InputInfo,
229+
):
230+
module_builder = self.module_builder
231+
# See if we know about it.
232+
mapping = module_builder.global_ref_tracker.track(py_value)
233+
if mapping.is_empty:
234+
raise ValueError(f"Cannot store value to unmapped global for: {info}")
235+
logger.debug("Resolved global for store %r", mapping)
236+
materialized_global: MaterializedGlobal = mapping.value # type: ignore
237+
converted_value = Operation.create(
238+
"torch_c.to_builtin_tensor",
239+
results=[materialized_global.ir_type],
240+
operands=[produced_ir_value],
241+
).result
242+
util_d.GlobalStoreOp(converted_value, materialized_global.symbol_name)
243+
222244
def resolve_literal(self, gni: GraphNodeImporter, literal: Any) -> Optional[Value]:
223245
# We support resolution of tracked reference types. Currently this
224246
# only includes Tensors. All others we let the importer do what it

core/shark_turbine/support/ir_imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""Unifies all imports of iree.compiler.ir into one place."""
99

1010
from iree.compiler.ir import (
11+
AsmState,
1112
Attribute,
1213
Block,
1314
BlockArgument,

core/tests/aot/compiled_exported_program_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,23 @@ class ParamsAsGlobalsModule(CompiledModule):
119119
2, module_str.count("util.global.load @_params.classifier.bias")
120120
)
121121

122+
def testBuffersAsGlobals(self):
123+
fxb = FxProgramsBuilder(SimpleBuffers())
124+
125+
@fxb.export_program(args=(torch.empty([128]),))
126+
def _compute1(module, x):
127+
return module.forward(x)
128+
129+
class BuffersAsGlobalsModule(CompiledModule):
130+
buffers = export_buffers(fxb.root_module, mutable=True)
131+
compute1 = _compute1
132+
133+
inst = BuffersAsGlobalsModule(context=Context(), import_to="import")
134+
module_str = str(CompiledModule.get_mlir_module(inst))
135+
self.assertIn("util.global private mutable @_buffers.buf", module_str)
136+
self.assertIn("%_buffers.buf = util.global.load @_buffers.buf", module_str)
137+
self.assertIn("util.global.store", module_str)
138+
122139

123140
class SimpleParams(nn.Module):
124141
def __init__(self):
@@ -129,6 +146,18 @@ def forward(self, x):
129146
return self.classifier(x)
130147

131148

149+
class SimpleBuffers(nn.Module):
150+
def __init__(self):
151+
super().__init__()
152+
self.register_buffer("buf", torch.randn(1))
153+
154+
def forward(self, x: torch.Tensor):
155+
sumx = (x).sum()
156+
output = x * self.buf
157+
self.buf.copy_(sumx)
158+
return output
159+
160+
132161
if __name__ == "__main__":
133162
logging.basicConfig(level=logging.DEBUG)
134163
unittest.main()

0 commit comments

Comments
 (0)