Skip to content

Commit 92833fa

Browse files
edknv and marcromeyn authored
Add Model class (#1126)
* Add Model class * Use BlockContainer in Models class * Add unit tests * Add docstrings * Add module_utils unit tests * Remove future work * move initialize() to module_utils * handle batch in training_step * Add output_schema() * check if model outputs have no targets when no target is provided * put loss and metrics on the same device * add docstrings to module_utils functions * move metric device setting to initialize * update logic for using model output targets --------- Co-authored-by: Marc Romeyn <marcromeyn@gmail.com>
1 parent 6ea8828 commit 92833fa

File tree

9 files changed

+777
-5
lines changed

9 files changed

+777
-5
lines changed

merlin/models/torch/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from merlin.models.torch.batch import Batch, Sequence
1818
from merlin.models.torch.block import Block, ParallelBlock
19+
from merlin.models.torch.blocks.mlp import MLPBlock
20+
from merlin.models.torch.models.base import Model
1921
from merlin.models.torch.outputs.base import ModelOutput
2022
from merlin.models.torch.outputs.classification import BinaryOutput
2123
from merlin.models.torch.outputs.regression import RegressionOutput
@@ -26,8 +28,10 @@
2628
"Batch",
2729
"BinaryOutput",
2830
"Block",
29-
"ParallelBlock",
31+
"MLPBlock",
32+
"Model",
3033
"ModelOutput",
34+
"ParallelBlock",
3135
"Sequence",
3236
"RegressionOutput",
3337
"RouterBlock",

merlin/models/torch/block.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
class Block(BlockContainer, SchemaTrackingMixin, RegistryMixin):
32-
"""A base-class that calls it's modules sequentially.
32+
"""A base-class that calls its modules sequentially.
3333
3434
Parameters
3535
----------
@@ -114,7 +114,7 @@ def copy(self) -> "Block":
114114

115115

116116
class ParallelBlock(Block):
117-
"""A base-class that calls it's modules in parallel.
117+
"""A base-class that calls its modules in parallel.
118118
119119
A ParallelBlock contains multiple branches that will be executed
120120
in parallel. Each branch can contain multiple modules, and
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#

merlin/models/torch/models/base.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
#
2+
# Copyright (c) 2023, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
from functools import reduce
17+
from typing import Dict, List, Optional, Sequence, Union
18+
19+
import torch
20+
from pytorch_lightning import LightningModule
21+
from torch import nn
22+
23+
from merlin.dataloader.torch import Loader
24+
from merlin.io import Dataset
25+
from merlin.models.torch.batch import Batch
26+
from merlin.models.torch.block import Block
27+
from merlin.models.torch.container import BlockContainer
28+
from merlin.models.torch.outputs.base import ModelOutput
29+
from merlin.models.torch.utils import module_utils
30+
from merlin.models.utils.registry import camelcase_to_snakecase
31+
from merlin.schema import Schema
32+
33+
34+
class Model(Block, LightningModule):
    """A Merlin model assembled from a sequence of blocks.

    Combines ``Block`` with PyTorch Lightning's ``LightningModule`` so that a
    model built out of pre-defined blocks can be trained directly with a
    Lightning ``Trainer``.

    Parameters
    ----------
    *blocks: nn.Module
        The modules that form the body of the model, executed in order.
    schema: Schema, optional
        A Merlin schema describing the model inputs. Default is None.
    optimizer: torch.optim.Optimizer, optional
        Optimizer class (or any factory following the same API) used by
        ``configure_optimizers``. Default is the Adam optimizer.

    Example usage
    -------------
    >>> model = Model(
    ...     TabularInputBlock(schema),
    ...     MLPBlock([32, 16]),
    ...     BinaryOutput(schema.select_by_tag(Tags.TARGET).first),
    ... )
    ... trainer = Trainer(max_epochs=1)
    ... with Loader(dataset, batch_size=16) as loader:
    ...     model.initialize(loader)
    ...     trainer.fit(model, loader)
    """

    def __init__(
        self,
        *blocks: nn.Module,
        schema: Optional[Schema] = None,
        optimizer=torch.optim.Adam,
    ):
        """Builds the model from the given blocks."""
        super().__init__()
        self.schema = schema

        # Three containers run in sequence: pre -> blocks -> post.
        self.pre = BlockContainer(name="pre")
        self.blocks = BlockContainer(name="blocks")
        for module in blocks:
            self.blocks.append(module)
        self.post = BlockContainer(name="post")

        self.optimizer = optimizer

    def initialize(self, data: Union[Dataset, Loader, Batch]):
        """Initializes the model by tracing it with a sample of `data`."""
        return module_utils.initialize(self, data)

    def forward(
        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
    ):
        """Runs `inputs` through the pre, main, and post containers in order."""
        outputs = inputs
        for container in (self.pre, self.blocks, self.post):
            for module in container.values:
                outputs = module(outputs, batch=batch)
        return outputs

    def training_step(self, batch, batch_idx):
        """Computes and logs the training loss for a single batch."""
        del batch_idx
        if isinstance(batch, Batch):
            features, targets = batch.features, batch.targets
        else:
            features, targets = batch

        predictions = self(features, batch=Batch(features, targets))

        results = compute_loss(predictions, targets, self.model_outputs())
        for name, value in results.items():
            self.log(f"train_{name}", value)

        return results["loss"]

    def configure_optimizers(self):
        """Instantiates the configured optimizer over the model parameters."""
        return self.optimizer(self.parameters())

    def model_outputs(self) -> List[ModelOutput]:
        """Collects every `ModelOutput` module contained in the model."""
        return module_utils.find_all_instances(self, ModelOutput)

    def first(self) -> nn.Module:
        """Returns the first module of the main block container."""
        return self.blocks.values[0]

    def last(self) -> nn.Module:
        """Returns the last module of the main block container."""
        return self.blocks.values[-1]

    def input_schema(self) -> Schema:
        """Returns the input schema of the model (empty when none is set)."""
        if self.schema:
            return self.schema
        # TODO: Implement logic when TabularInputBlock is available.
        return Schema([])

    def output_schema(self) -> Schema:
        """Returns the union of the output schemas of all children that define one."""
        schemas = [
            child.output_schema()
            for child in module_utils.get_all_children(self)
            if hasattr(child, "output_schema")
        ]

        if not schemas:
            raise RuntimeError("No output schema found")

        return reduce(lambda lhs, rhs: lhs + rhs, schemas)  # type: ignore
149+
150+
151+
def compute_loss(
    predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
    targets: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]],
    model_outputs: Sequence[ModelOutput],
    compute_metrics: bool = True,
) -> Dict[str, torch.Tensor]:
    """Compute the loss and metrics for the given model outputs.

    This function takes in predictions and targets, and a list of model
    outputs. It computes the loss using the loss function of each model output
    and averages it. If `compute_metrics` is set to True, it also computes the
    metrics defined in each model output.

    Parameters
    ----------
    predictions: Union[torch.Tensor, Dict[str, torch.Tensor]]
        The predictions from the model.
    targets: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]
        The ground truth targets. When None (or when a model output's column
        is missing from a target dict), the model output's default `target`
        value is broadcast as the target instead.
    model_outputs: Sequence[ModelOutput]
        A list of model outputs. Each model output must have a defined loss
        function.
    compute_metrics: bool, optional
        Whether to compute metrics defined in each model output. Default: True.

    Returns
    -------
    Dict[str, torch.Tensor]
        A dictionary containing the loss and the computed metrics (if any).

    Raises
    ------
    RuntimeError
        If no model outputs are provided, or if a model output's column is
        missing from a predictions dict.
    ValueError
        If predictions or targets have an unsupported type, or if a target is
        needed but the model output defines no default target.

    Example usage
    -------------
    >>> predictions = torch.tensor([0.2, 0.3, 0.6, 0.8])
    >>> targets = torch.tensor([1.0, 0.0, 1.0, 0.0], dtype=torch.float32)
    >>> binary_output = mm.BinaryOutput(ColumnSchema("target"))
    >>> results = compute_loss(predictions, targets, [binary_output])
    >>> results["loss"]
    tensor(0.7653)
    >>> results["binary_accuracy"]
    tensor(0.5000)
    """
    if len(model_outputs) < 1:
        raise RuntimeError("No model outputs found.")

    results: Dict[str, torch.Tensor] = {}
    losses = []
    for model_out in model_outputs:
        name = model_out.output_schema.first.name

        # Resolve the prediction column first, so a missing column raises the
        # explicit RuntimeError below instead of a bare KeyError when we fall
        # back to the model output's default target.
        if isinstance(predictions, dict):
            if name not in predictions:
                raise RuntimeError(f"Column '{name}' not found in predictions")
            _predictions = predictions[name]
        elif isinstance(predictions, torch.Tensor):
            _predictions = predictions
        else:
            raise ValueError(f"Unknown 'predictions' type: {type(predictions)}")

        if targets is None or (isinstance(targets, dict) and name not in targets):
            # No explicit target for this output: broadcast its default
            # target value. Guard against a missing OR None default — the
            # original hasattr check let `target = None` through, which would
            # crash with an opaque TypeError in the multiplication below.
            if getattr(model_out, "target", None) is None:
                raise ValueError(f"'{model_out.__class__.__name__}' has no target.")
            _targets = torch.ones_like(_predictions) * model_out.target
        elif isinstance(targets, dict):
            _targets = targets[name]
        elif isinstance(targets, torch.Tensor):
            _targets = targets
        else:
            raise ValueError(f"Unknown 'targets' type: {type(targets)}")

        losses.append(model_out.loss(_predictions, _targets))

        if not compute_metrics:
            continue

        for metric in model_out.metrics:
            metric_name = camelcase_to_snakecase(metric.__class__.__name__)
            results[metric_name] = metric(_predictions, _targets)

    # Average the per-output losses on their own device and dtype. Seeding an
    # accumulator with `torch.tensor(0.0)` would pin it to the CPU and fail to
    # combine with losses computed on another device.
    results["loss"] = torch.stack(losses).mean() if len(losses) > 1 else losses[0]
    return results

0 commit comments

Comments
 (0)