-
Notifications
You must be signed in to change notification settings - Fork 54
Adding SchemaTrackingMixin #1109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
708ef75
Adding SchemaTrackingMixin
marcromeyn f8a9985
Merge branch 'main' into torch/track-schema
marcromeyn 352ad32
Small fix in test_tensor
marcromeyn c0e2fd3
:Merge branch 'torch/track-schema' of github.com:NVIDIA-Merlin/models…
marcromeyn 2cbc9fd
Add schema-tracking to Block
marcromeyn cf3ee1a
Merge branch 'main' into torch/track-schema
marcromeyn 300d01a
Merge branch 'main' into torch/track-schema
marcromeyn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| import torch | ||
|
|
||
| from merlin.schema import ColumnSchema, Schema, Tags | ||
|
|
||
|
|
||
| class SchemaTrackingMixin: | ||
| """ | ||
| A mixin class for PyTorch modules to track the output shapes and dtypes | ||
| of the forward pass. This is used in order to automatically generate | ||
| the output-schema. | ||
| It registers a hook to capture this information and | ||
| provides methods to access the output schema, as well as to set the module | ||
| in training or evaluation mode. | ||
| """ | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
| self._register_schema_tracking_hook() | ||
|
|
||
| def _post_forward_hook(self, module, input, output): | ||
| """Hook function to be called after the forward pass of the module. | ||
| Parameters | ||
| ---------- | ||
| module : torch.nn.Module | ||
| The module for which the forward pass was called. | ||
| input : tuple | ||
| The input arguments passed to the forward method. | ||
| output : torch.Tensor or dict | ||
| The output of the forward method. | ||
| """ | ||
| if not module._forward_called: | ||
| if isinstance(output, dict): | ||
| for key, value in output.items(): | ||
| module._output_shapes[key] = value.shape | ||
| module._output_dtypes[key] = value.dtype | ||
| else: | ||
| module._output_shapes["output"] = output.shape | ||
| module._output_dtypes["output"] = output.dtype | ||
| module._forward_called = True | ||
| module._handle.remove() | ||
|
|
||
| def _register_schema_tracking_hook(self): | ||
| """ | ||
| Register the post forward hook to the module. | ||
| """ | ||
| self._forward_called = False | ||
| self._handle = None | ||
| self._output_shapes = {} | ||
| self._output_dtypes = {} | ||
|
|
||
| if self._handle is None: | ||
| self._handle = self.register_forward_hook(self._post_forward_hook) | ||
|
|
||
| def output_schema(self) -> Schema: | ||
marcromeyn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Get the output schema of the module. | ||
| Returns | ||
| ------- | ||
| Schema | ||
| The output schema of the module. | ||
| Raises | ||
| ------ | ||
| RuntimeError | ||
| If forward() has not been called before calling this method. | ||
| """ | ||
|
|
||
| if not hasattr(self, "_output_shapes"): | ||
| raise RuntimeError( | ||
| "Schema-tracking hook not registered, use `_register_schema_tracking_hook`." | ||
| ) | ||
|
|
||
| if not self._forward_called: | ||
| raise RuntimeError("forward() must be called before output_schema() can be called.") | ||
|
|
||
| columns = [] | ||
|
|
||
| for name, shape in self._output_shapes.items(): | ||
| dtype = self._output_dtypes[name] | ||
| dims = (None,) + tuple(shape) | ||
| tags = None | ||
|
|
||
| if len(shape) > 1 and dtype != torch.int32: | ||
| tags = [Tags.EMBEDDING] | ||
marcromeyn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| columns.append(ColumnSchema(name, dims=dims, tags=tags, dtype=dtype)) | ||
|
|
||
| return Schema(columns) | ||
|
|
||
| def train(self, mode=True): | ||
| self._register_schema_tracking_hook() | ||
| return super().train(mode) | ||
|
|
||
| def eval(self): | ||
| self._register_schema_tracking_hook() | ||
| return super().eval() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| import pytest | ||
| import torch | ||
| from torch import nn | ||
|
|
||
| from merlin.models.torch.utils.module_utils import module_test | ||
| from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin | ||
| from merlin.schema import Schema, Tags | ||
|
|
||
|
|
||
| class TrackedModule(SchemaTrackingMixin, nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.linear = nn.LazyLinear(10) | ||
|
|
||
| def forward(self, x: torch.Tensor): | ||
| return self.linear(x) | ||
|
|
||
|
|
||
| class TrackedDictModule(SchemaTrackingMixin, nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.linear = nn.LazyLinear(10) | ||
|
|
||
| def forward(self, x: torch.Tensor): | ||
| return {"a": self.linear(x), "b": self.linear(x)} | ||
|
|
||
|
|
||
| class TestSchemaTrackingMixin: | ||
| def test_tensor(self): | ||
| inputs = torch.randn(1, 5) | ||
| tracked_module = TrackedModule() | ||
| module_test(tracked_module, inputs) | ||
|
|
||
| schema = tracked_module.output_schema() | ||
| assert isinstance(schema, Schema) | ||
| assert len(schema) == 1 | ||
| assert len(schema.select_by_tag(Tags.EMBEDDING)) == 1 | ||
|
|
||
| def test_dict(self): | ||
| inputs = torch.randn(1, 5) | ||
| tracked_module = TrackedDictModule() | ||
|
|
||
| outputs = tracked_module(inputs) | ||
| traced_outputs = module_test(tracked_module, inputs) | ||
| assert torch.equal(outputs["a"], traced_outputs["a"]) | ||
| assert torch.equal(outputs["b"], traced_outputs["b"]) | ||
|
|
||
| schema = tracked_module.output_schema() | ||
| assert isinstance(schema, Schema) | ||
| assert len(schema) == 2 | ||
| assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2 | ||
|
|
||
| def test_exception(self): | ||
| tracked_module = TrackedModule() | ||
| with pytest.raises(RuntimeError): | ||
| tracked_module.output_schema() | ||
|
|
||
| def test_train(self): | ||
| tracked_module = TrackedModule() | ||
| tracked_module(torch.randn(1, 5)) | ||
|
|
||
| tracked_module.train() | ||
| assert not tracked_module._forward_called | ||
|
|
||
| def test_eval(self): | ||
| tracked_module = TrackedModule() | ||
| tracked_module(torch.randn(1, 5)) | ||
|
|
||
| tracked_module.eval() | ||
| assert not tracked_module._forward_called |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this
"output"key is for the case where a module outputs a single tensor? Would something that uses this output schema depend on the name of this column name elsewhere?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question, not sure actually. Maybe we should extract it in a constant somewhere?