Skip to content

Commit b6d6645

Browse files
authored
Adding SchemaTrackingMixin (#1109)
* Adding SchemaTrackingMixin * Small fix in test_tensor * Add schema-tracking to Block
1 parent b5dff16 commit b6d6645

File tree

4 files changed

+202
-7
lines changed

4 files changed

+202
-7
lines changed

merlin/models/torch/block.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
from merlin.models.torch.container import BlockContainer, BlockContainerDict
2525
from merlin.models.torch.link import Link, LinkType
2626
from merlin.models.torch.registry import registry
27+
from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin
2728
from merlin.models.utils.registry import RegistryMixin
2829

2930

30-
class Block(BlockContainer, RegistryMixin):
31+
class Block(BlockContainer, SchemaTrackingMixin, RegistryMixin):
3132
"""A base-class that calls its modules sequentially.
3233
3334
Parameters
@@ -36,12 +37,16 @@ class Block(BlockContainer, RegistryMixin):
3637
Variable length argument list of PyTorch modules to be contained in the block.
3738
name : Optional[str], default = None
3839
The name of the block. If None, no name is assigned.
40+
track_schema : bool, default = True
41+
If True, the schema of the output tensors are tracked.
3942
"""
4043

4144
registry = registry
4245

43-
def __init__(self, *module: nn.Module, name: Optional[str] = None):
46+
def __init__(self, *module: nn.Module, name: Optional[str] = None, track_schema: bool = True):
4447
super().__init__(*module, name=name)
48+
if track_schema:
49+
self._register_schema_tracking_hook()
4550

4651
def forward(
4752
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
@@ -138,17 +143,16 @@ class ParallelBlock(Block):
138143
Variable length argument list of PyTorch modules to be contained in the block.
139144
name : Optional[str], default = None
140145
The name of the block. If None, no name is assigned.
146+
track_schema : bool, default = True
147+
If True, the schema of the output tensors are tracked.
141148
"""
142149

143-
def __init__(
144-
self,
145-
*inputs: Union[nn.Module, Dict[str, nn.Module]],
146-
):
150+
def __init__(self, *inputs: Union[nn.Module, Dict[str, nn.Module]], track_schema: bool = True):
147151
pre = BlockContainer(name="pre")
148152
branches = BlockContainerDict(*inputs)
149153
post = BlockContainer(name="post")
150154

151-
super().__init__()
155+
super().__init__(track_schema=track_schema)
152156

153157
self.pre = pre
154158
self.branches = branches
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import torch
2+
3+
from merlin.schema import ColumnSchema, Schema, Tags
4+
5+
6+
class SchemaTrackingMixin:
7+
"""
8+
A mixin class for PyTorch modules to track the output shapes and dtypes
9+
of the forward pass. This is used in order to automatically generate
10+
the output-schema.
11+
12+
It registers a hook to capture this information and
13+
provides methods to access the output schema, as well as to set the module
14+
in training or evaluation mode.
15+
"""
16+
17+
def __init__(self):
18+
super().__init__()
19+
self._register_schema_tracking_hook()
20+
21+
def _post_forward_hook(self, module, input, output):
22+
"""Hook function to be called after the forward pass of the module.
23+
24+
Parameters
25+
----------
26+
module : torch.nn.Module
27+
The module for which the forward pass was called.
28+
input : tuple
29+
The input arguments passed to the forward method.
30+
output : torch.Tensor or dict
31+
The output of the forward method.
32+
"""
33+
if not module._forward_called:
34+
if isinstance(output, dict):
35+
for key, value in output.items():
36+
module._output_shapes[key] = value.shape
37+
module._output_dtypes[key] = value.dtype
38+
else:
39+
module._output_shapes["output"] = output.shape
40+
module._output_dtypes["output"] = output.dtype
41+
module._forward_called = True
42+
module._handle.remove()
43+
44+
def _register_schema_tracking_hook(self):
45+
"""
46+
Register the post forward hook to the module.
47+
"""
48+
self._forward_called = False
49+
self._handle = None
50+
self._output_shapes = {}
51+
self._output_dtypes = {}
52+
53+
if self._handle is None:
54+
self._handle = self.register_forward_hook(self._post_forward_hook)
55+
56+
def output_schema(self) -> Schema:
57+
"""Get the output schema of the module.
58+
59+
Returns
60+
-------
61+
Schema
62+
The output schema of the module.
63+
64+
Raises
65+
------
66+
RuntimeError
67+
If forward() has not been called before calling this method.
68+
"""
69+
70+
if not hasattr(self, "_output_shapes"):
71+
raise RuntimeError(
72+
"Schema-tracking hook not registered, use `_register_schema_tracking_hook`."
73+
)
74+
75+
if not self._forward_called:
76+
raise RuntimeError("forward() must be called before output_schema() can be called.")
77+
78+
columns = []
79+
80+
for name, shape in self._output_shapes.items():
81+
dtype = self._output_dtypes[name]
82+
dims = (None,) + tuple(shape)
83+
tags = None
84+
85+
if len(shape) > 1 and dtype != torch.int32:
86+
tags = [Tags.EMBEDDING]
87+
88+
columns.append(ColumnSchema(name, dims=dims, tags=tags, dtype=dtype))
89+
90+
return Schema(columns)
91+
92+
def train(self, mode=True):
93+
self._register_schema_tracking_hook()
94+
return super().train(mode)
95+
96+
def eval(self):
97+
self._register_schema_tracking_hook()
98+
return super().eval()

tests/unit/torch/test_block.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from merlin.models.torch.block import Block, ParallelBlock
2525
from merlin.models.torch.container import BlockContainer, BlockContainerDict
2626
from merlin.models.torch.utils import module_utils
27+
from merlin.schema import Tags
2728

2829

2930
class PlusOne(nn.Module):
@@ -50,6 +51,14 @@ def test_identity(self):
5051

5152
assert torch.equal(inputs, outputs)
5253

54+
schema = block.output_schema()
55+
assert schema.first.dtype.name == str(outputs.dtype).split(".")[-1]
56+
57+
def test_no_schema_tracking(self):
58+
block = Block(track_schema=False)
59+
with pytest.raises(RuntimeError, match="Schema-tracking hook not registered"):
60+
block.output_schema()
61+
5362
def test_insertion(self):
5463
block = Block()
5564
block.prepend(PlusOne())
@@ -148,6 +157,20 @@ def test_forward_dict_duplicate(self):
148157
with pytest.raises(RuntimeError):
149158
pb(inputs)
150159

160+
def test_schema_tracking(self):
161+
pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()})
162+
163+
inputs = torch.randn(1, 3)
164+
outputs = pb(inputs)
165+
166+
schema = pb.output_schema()
167+
168+
for name in outputs:
169+
assert name in schema.column_names
170+
assert schema[name].dtype.name == str(outputs[name].dtype).split(".")[-1]
171+
172+
assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2
173+
151174
def test_forward_tuple(self):
152175
inputs = torch.randn(1, 3)
153176
pb = ParallelBlock({"test": PlusOneTuple()})
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
import torch
3+
from torch import nn
4+
5+
from merlin.models.torch.utils.module_utils import module_test
6+
from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin
7+
from merlin.schema import Schema, Tags
8+
9+
10+
class TrackedModule(SchemaTrackingMixin, nn.Module):
    """Fixture module returning a single tensor, used to exercise schema tracking."""

    def __init__(self):
        super().__init__()
        self.linear = nn.LazyLinear(10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.linear(x)
        return projected
17+
18+
19+
class TrackedDictModule(SchemaTrackingMixin, nn.Module):
    """Fixture module returning a dict of tensors, used to exercise schema tracking."""

    def __init__(self):
        super().__init__()
        self.linear = nn.LazyLinear(10)

    def forward(self, x: torch.Tensor):
        outputs = {}
        # Each key gets its own projection call, matching the original
        # (distinct tensor objects with equal values).
        for key in ("a", "b"):
            outputs[key] = self.linear(x)
        return outputs
26+
27+
28+
class TestSchemaTrackingMixin:
    """Unit tests for SchemaTrackingMixin output-schema capture and mode resets."""

    def test_tensor(self):
        module = TrackedModule()
        module_test(module, torch.randn(1, 5))

        schema = module.output_schema()
        assert isinstance(schema, Schema)
        assert len(schema) == 1
        assert len(schema.select_by_tag(Tags.EMBEDDING)) == 1

    def test_dict(self):
        inputs = torch.randn(1, 5)
        module = TrackedDictModule()

        eager_outputs = module(inputs)
        traced_outputs = module_test(module, inputs)
        for key in ("a", "b"):
            assert torch.equal(eager_outputs[key], traced_outputs[key])

        schema = module.output_schema()
        assert isinstance(schema, Schema)
        assert len(schema) == 2
        assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2

    def test_exception(self):
        # output_schema() before any forward pass must raise.
        with pytest.raises(RuntimeError):
            TrackedModule().output_schema()

    def test_train(self):
        module = TrackedModule()
        module(torch.randn(1, 5))

        # Switching to train mode resets the captured state.
        module.train()
        assert not module._forward_called

    def test_eval(self):
        module = TrackedModule()
        module(torch.randn(1, 5))

        # Switching to eval mode resets the captured state.
        module.eval()
        assert not module._forward_called

0 commit comments

Comments
 (0)