18 changes: 11 additions & 7 deletions merlin/models/torch/block.py
@@ -24,10 +24,11 @@
 from merlin.models.torch.container import BlockContainer, BlockContainerDict
 from merlin.models.torch.link import Link, LinkType
 from merlin.models.torch.registry import registry
+from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin
 from merlin.models.utils.registry import RegistryMixin


-class Block(BlockContainer, RegistryMixin):
+class Block(BlockContainer, SchemaTrackingMixin, RegistryMixin):
     """A base-class that calls its modules sequentially.

     Parameters
@@ -36,12 +37,16 @@ class Block(BlockContainer, RegistryMixin):
         Variable length argument list of PyTorch modules to be contained in the block.
     name : Optional[str], default = None
         The name of the block. If None, no name is assigned.
+    track_schema : bool, default = True
+        If True, the schema of the output tensors is tracked.
     """

     registry = registry

-    def __init__(self, *module: nn.Module, name: Optional[str] = None):
+    def __init__(self, *module: nn.Module, name: Optional[str] = None, track_schema: bool = True):
         super().__init__(*module, name=name)
+        if track_schema:
+            self._register_schema_tracking_hook()

     def forward(
         self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
@@ -138,17 +143,16 @@ class ParallelBlock(Block):
         Variable length argument list of PyTorch modules to be contained in the block.
     name : Optional[str], default = None
         The name of the block. If None, no name is assigned.
+    track_schema : bool, default = True
+        If True, the schema of the output tensors is tracked.
     """

-    def __init__(
-        self,
-        *inputs: Union[nn.Module, Dict[str, nn.Module]],
-    ):
+    def __init__(self, *inputs: Union[nn.Module, Dict[str, nn.Module]], track_schema: bool = True):
         pre = BlockContainer(name="pre")
         branches = BlockContainerDict(*inputs)
         post = BlockContainer(name="post")

-        super().__init__()
+        super().__init__(track_schema=track_schema)

         self.pre = pre
         self.branches = branches
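To make the new flag concrete, here is a minimal usage sketch based on the diff above and the tests further down; the nn.LazyLinear layer and input shapes are illustrative stand-ins, not part of this PR:

import torch
from torch import nn

from merlin.models.torch.block import Block

# Schema tracking is on by default; the forward hook records output
# shapes and dtypes on the first call, then removes itself.
block = Block(nn.LazyLinear(10))
outputs = block(torch.randn(2, 5))

schema = block.output_schema()
print(schema.column_names)  # single-tensor outputs land under the "output" column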
98 changes: 98 additions & 0 deletions merlin/models/torch/utils/schema_utils.py
@@ -0,0 +1,98 @@
import torch

from merlin.schema import ColumnSchema, Schema, Tags


class SchemaTrackingMixin:
    """
    A mixin class for PyTorch modules to track the output shapes and dtypes
    of the forward pass. This is used to automatically generate the output
    schema.

    It registers a hook to capture this information and provides methods to
    access the output schema, as well as to set the module in training or
    evaluation mode.
    """

    def __init__(self):
        super().__init__()
        self._register_schema_tracking_hook()

    def _post_forward_hook(self, module, input, output):
        """Hook function to be called after the forward pass of the module.

        Parameters
        ----------
        module : torch.nn.Module
            The module for which the forward pass was called.
        input : tuple
            The input arguments passed to the forward method.
        output : torch.Tensor or dict
            The output of the forward method.
        """
        if not module._forward_called:
            if isinstance(output, dict):
                for key, value in output.items():
                    module._output_shapes[key] = value.shape
                    module._output_dtypes[key] = value.dtype
            else:
                module._output_shapes["output"] = output.shape
                module._output_dtypes["output"] = output.dtype
            module._forward_called = True
            module._handle.remove()

Review comment (Contributor): This "output" key is for the case where a module outputs a single tensor? Would something that uses this output schema depend on the name of this column elsewhere?

Reply (Contributor Author): Good question, not sure actually. Maybe we should extract it into a constant somewhere?

    def _register_schema_tracking_hook(self):
        """
        Register the post-forward hook on the module.
        """
        self._forward_called = False
        self._output_shapes = {}
        self._output_dtypes = {}

        # Remove any stale handle before registering, so repeated calls
        # (e.g. from train()/eval()) never stack duplicate hooks.
        if getattr(self, "_handle", None) is not None:
            self._handle.remove()
        self._handle = self.register_forward_hook(self._post_forward_hook)

    def output_schema(self) -> Schema:
        """Get the output schema of the module.

        Returns
        -------
        Schema
            The output schema of the module.

        Raises
        ------
        RuntimeError
            If forward() has not been called before calling this method.
        """
        if not hasattr(self, "_output_shapes"):
            raise RuntimeError(
                "Schema-tracking hook not registered, use `_register_schema_tracking_hook`."
            )

        if not self._forward_called:
            raise RuntimeError("forward() must be called before output_schema() can be called.")

        columns = []

        for name, shape in self._output_shapes.items():
            dtype = self._output_dtypes[name]
            dims = (None,) + tuple(shape)
            tags = None

            # Multi-dimensional, non-int32 outputs are tagged as embeddings.
            if len(shape) > 1 and dtype != torch.int32:
                tags = [Tags.EMBEDDING]

            columns.append(ColumnSchema(name, dims=dims, tags=tags, dtype=dtype))

        return Schema(columns)

    def train(self, mode=True):
        self._register_schema_tracking_hook()
        return super().train(mode)

    def eval(self):
        self._register_schema_tracking_hook()
        return super().eval()
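Following up on the review thread above, a hypothetical sketch of the suggested refactor; the constant name DEFAULT_TENSOR_COLUMN is an assumption, not part of this PR:

# Hypothetical follow-up (not in this PR): hoist the fallback column name
# into a module-level constant so downstream code that reads the output
# schema doesn't hard-code the string "output".
DEFAULT_TENSOR_COLUMN = "output"


def _post_forward_hook(self, module, input, output):
    if not module._forward_called:
        if isinstance(output, dict):
            for key, value in output.items():
                module._output_shapes[key] = value.shape
                module._output_dtypes[key] = value.dtype
        else:
            # Single-tensor outputs all land under one shared, named column.
            module._output_shapes[DEFAULT_TENSOR_COLUMN] = output.shape
            module._output_dtypes[DEFAULT_TENSOR_COLUMN] = output.dtype
        module._forward_called = True
        module._handle.remove()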
23 changes: 23 additions & 0 deletions tests/unit/torch/test_block.py
@@ -24,6 +24,7 @@
 from merlin.models.torch.block import Block, ParallelBlock
 from merlin.models.torch.container import BlockContainer, BlockContainerDict
 from merlin.models.torch.utils import module_utils
+from merlin.schema import Tags


 class PlusOne(nn.Module):
@@ -50,6 +51,14 @@ def test_identity(self):

         assert torch.equal(inputs, outputs)
+
+        schema = block.output_schema()
+        assert schema.first.dtype.name == str(outputs.dtype).split(".")[-1]
+
+    def test_no_schema_tracking(self):
+        block = Block(track_schema=False)
+        with pytest.raises(RuntimeError, match="Schema-tracking hook not registered"):
+            block.output_schema()

     def test_insertion(self):
         block = Block()
         block.prepend(PlusOne())
@@ -148,6 +157,20 @@ def test_forward_dict_duplicate(self):
         with pytest.raises(RuntimeError):
             pb(inputs)

+    def test_schema_tracking(self):
+        pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()})
+
+        inputs = torch.randn(1, 3)
+        outputs = pb(inputs)
+
+        schema = pb.output_schema()
+
+        for name in outputs:
+            assert name in schema.column_names
+            assert schema[name].dtype.name == str(outputs[name].dtype).split(".")[-1]
+
+        assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2
+
     def test_forward_tuple(self):
         inputs = torch.randn(1, 3)
         pb = ParallelBlock({"test": PlusOneTuple()})
70 changes: 70 additions & 0 deletions tests/unit/torch/utils/test_schema_utils.py
@@ -0,0 +1,70 @@
import pytest
import torch
from torch import nn

from merlin.models.torch.utils.module_utils import module_test
from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin
from merlin.schema import Schema, Tags


class TrackedModule(SchemaTrackingMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.LazyLinear(10)

    def forward(self, x: torch.Tensor):
        return self.linear(x)


class TrackedDictModule(SchemaTrackingMixin, nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.LazyLinear(10)

    def forward(self, x: torch.Tensor):
        return {"a": self.linear(x), "b": self.linear(x)}


class TestSchemaTrackingMixin:
    def test_tensor(self):
        inputs = torch.randn(1, 5)
        tracked_module = TrackedModule()
        module_test(tracked_module, inputs)

        schema = tracked_module.output_schema()
        assert isinstance(schema, Schema)
        assert len(schema) == 1
        assert len(schema.select_by_tag(Tags.EMBEDDING)) == 1

    def test_dict(self):
        inputs = torch.randn(1, 5)
        tracked_module = TrackedDictModule()

        outputs = tracked_module(inputs)
        traced_outputs = module_test(tracked_module, inputs)
        assert torch.equal(outputs["a"], traced_outputs["a"])
        assert torch.equal(outputs["b"], traced_outputs["b"])

        schema = tracked_module.output_schema()
        assert isinstance(schema, Schema)
        assert len(schema) == 2
        assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2

    def test_exception(self):
        tracked_module = TrackedModule()
        with pytest.raises(RuntimeError):
            tracked_module.output_schema()

    def test_train(self):
        tracked_module = TrackedModule()
        tracked_module(torch.randn(1, 5))

        tracked_module.train()
        assert not tracked_module._forward_called

    def test_eval(self):
        tracked_module = TrackedModule()
        tracked_module(torch.randn(1, 5))

        tracked_module.eval()
        assert not tracked_module._forward_called
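The train/eval overrides are the subtle part of the mixin: switching modes re-arms the tracking hook, so the schema is re-captured on the next forward pass. A minimal sketch of that lifecycle, reusing the TrackedModule class defined in the tests above:

import torch

module = TrackedModule()
module(torch.randn(1, 5))
schema = module.output_schema()  # OK: forward() has already run

module.eval()  # re-registers the hook and resets _forward_called
# Calling output_schema() here would raise:
#   RuntimeError: forward() must be called before output_schema() can be called.

module(torch.randn(1, 5))        # hook fires again in eval mode
schema = module.output_schema()  # schema captured afresh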