|
| 1 | +# |
| 2 | +# Copyright (c) 2023, NVIDIA CORPORATION. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# |
| 16 | +from functools import reduce |
| 17 | +from typing import Dict, List, Optional, Sequence, Union |
| 18 | + |
| 19 | +import torch |
| 20 | +from pytorch_lightning import LightningModule |
| 21 | +from torch import nn |
| 22 | + |
| 23 | +from merlin.dataloader.torch import Loader |
| 24 | +from merlin.io import Dataset |
| 25 | +from merlin.models.torch.batch import Batch |
| 26 | +from merlin.models.torch.block import Block |
| 27 | +from merlin.models.torch.container import BlockContainer |
| 28 | +from merlin.models.torch.outputs.base import ModelOutput |
| 29 | +from merlin.models.torch.utils import module_utils |
| 30 | +from merlin.models.utils.registry import camelcase_to_snakecase |
| 31 | +from merlin.schema import Schema |
| 32 | + |
| 33 | + |
class Model(Block, LightningModule):
    """
    Merlin Model class.

    The Model class extends from both the Block and LightningModule classes. It
    allows for easy construction of models using pre-defined blocks.

    Parameters
    ----------
    *blocks: nn.Module
        One or more blocks that make up the core functionality of the model.
    schema: Schema, optional
        A Merlin schema. Default is None.
    optimizer: torch.optim.Optimizer, optional
        A PyTorch optimizer class (or any custom optimizer factory that
        follows the same API). It is instantiated lazily in
        ``configure_optimizers``. Default is the Adam optimizer class.

    Example usage
    -------------
    >>> model = Model(
    ...    TabularInputBlock(schema),
    ...    MLPBlock([32, 16]),
    ...    BinaryOutput(schema.select_by_tag(Tags.TARGET).first),
    ... )
    ... trainer = Trainer(max_epochs=1)
    ... with Loader(dataset, batch_size=16) as loader:
    ...     model.initialize(loader)
    ...     trainer.fit(model, loader)
    """

    def __init__(
        self,
        *blocks: nn.Module,
        schema: Optional[Schema] = None,
        optimizer=torch.optim.Adam,
    ):
        """Initializes `Model` class"""
        super().__init__()
        self.schema = schema

        # `pre` and `post` start empty; transforms can be appended later.
        self.pre = BlockContainer(name="pre")
        self.blocks = BlockContainer(name="blocks")
        for block in blocks:
            self.blocks.append(block)
        self.post = BlockContainer(name="post")

        # Stored as a class/factory, not an instance: parameters are only
        # known after the model is built, so we instantiate it in
        # `configure_optimizers`.
        self.optimizer = optimizer

    def initialize(self, data: Union[Dataset, Loader, Batch]):
        """Initializes the model based on a given data set.

        Runs a sample batch through the model so that all lazy modules
        materialize their parameters.
        """
        return module_utils.initialize(self, data)

    def forward(
        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
    ):
        """Performs a forward pass through the model.

        Inputs flow through the `pre`, `blocks`, and `post` containers in
        order; each sub-module receives the (optional) `batch` as context.
        """
        outputs = inputs
        for pre in self.pre.values:
            outputs = pre(outputs, batch=batch)
        for block in self.blocks.values:
            outputs = block(outputs, batch=batch)
        for post in self.post.values:
            outputs = post(outputs, batch=batch)
        return outputs

    def training_step(self, batch, batch_idx):
        """Performs a training step with a single batch.

        Accepts either a `Batch` instance or a `(features, targets)` tuple.
        Logs the loss and all metrics under a ``train_`` prefix and returns
        the loss tensor for the Lightning optimizer loop.
        """
        del batch_idx
        if isinstance(batch, Batch):
            # Reuse the incoming Batch as-is so any fields beyond
            # features/targets (e.g. sequence metadata) are preserved instead
            # of being dropped by reconstructing a new Batch.
            _batch = batch
        else:
            features, targets = batch
            _batch = Batch(features, targets)

        predictions = self(_batch.features, batch=_batch)

        loss_and_metrics = compute_loss(predictions, _batch.targets, self.model_outputs())
        for name, value in loss_and_metrics.items():
            self.log(f"train_{name}", value)

        return loss_and_metrics["loss"]

    def configure_optimizers(self):
        """Configures the optimizer for the model."""
        return self.optimizer(self.parameters())

    def model_outputs(self) -> List[ModelOutput]:
        """Finds all instances of `ModelOutput` in the model."""
        return module_utils.find_all_instances(self, ModelOutput)

    def first(self) -> nn.Module:
        """Returns the first block in the model.

        Raises IndexError if the model was constructed with no blocks.
        """
        return self.blocks.values[0]

    def last(self) -> nn.Module:
        """Returns the last block in the model.

        Raises IndexError if the model was constructed with no blocks.
        """
        return self.blocks.values[-1]

    def input_schema(self) -> Schema:
        """Returns the input schema of the model."""
        # An explicitly provided (non-empty) schema wins; an empty or missing
        # schema falls through to the empty-schema default below.
        if self.schema:
            return self.schema
        # TODO: Implement logic when TabularInputBlock is available.
        return Schema([])

    def output_schema(self) -> Schema:
        """Returns the output schema of the model.

        The result is the union of the output schemas of all children that
        expose an ``output_schema`` method.

        Raises
        ------
        RuntimeError
            If no child defines an output schema.
        """
        output_schemas = []
        for child in module_utils.get_all_children(self):
            if hasattr(child, "output_schema"):
                output_schemas.append(child.output_schema())

        if not output_schemas:
            raise RuntimeError("No output schema found")

        return reduce(lambda a, b: a + b, output_schemas)  # type: ignore
| 149 | + |
| 150 | + |
def compute_loss(
    predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
    targets: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]],
    model_outputs: Sequence["ModelOutput"],
    compute_metrics: bool = True,
) -> Dict[str, torch.Tensor]:
    """Compute the loss and metrics for the given model outputs.

    This function takes in predictions and targets, and a list of model
    outputs. It computes the loss using the loss function of each model output
    and averages it. If `compute_metrics` is set to True, it also computes the
    metrics defined in each model output.

    Parameters
    ----------
    predictions: Union[torch.Tensor, Dict[str, torch.Tensor]]
        The predictions from the model.
    targets: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]
        The ground truth targets. When None (or when a target column is
        missing from a targets dict), the model output's own `target`
        attribute is broadcast as the target.
    model_outputs: Sequence[ModelOutput]
        A list of model outputs. Each model output must have a defined loss
        function.
    compute_metrics: bool, optional
        Whether to compute metrics defined in each model output. Default: True.

    Returns
    -------
    Dict[str, torch.Tensor]
        A dictionary containing the loss and the computed metrics (if any).

    Raises
    ------
    RuntimeError
        If no model outputs are provided, or if a prediction column named
        by a model output is missing from a predictions dict.
    ValueError
        If predictions or targets have an unsupported type, or a model
        output has no target to fall back on.

    Example usage
    -------------
    >>> predictions = torch.tensor([0.2, 0.3, 0.6, 0.8])
    >>> targets = torch.tensor([1.0, 0.0, 1.0, 0.0], dtype=torch.float32)
    >>> binary_output = mm.BinaryOutput(ColumnSchema("target"))
    >>> results = compute_loss(predictions, targets, [binary_output])
    >>> results["loss"]
    tensor(0.7653)
    >>> results["binary_accuracy"]
    tensor(0.5000)
    """
    if len(model_outputs) < 1:
        raise RuntimeError("No model outputs found.")

    # Accumulate from a Python scalar rather than torch.tensor(0.0): the
    # latter lives on the CPU and raises a device-mismatch error when the
    # per-output losses live on the GPU. Starting at 0.0 lets the sum inherit
    # the device/dtype of the first real loss (at least one output is
    # guaranteed by the check above).
    loss: Union[float, torch.Tensor] = 0.0
    metrics: Dict[str, torch.Tensor] = {}
    for model_out in model_outputs:
        name = model_out.output_schema.first.name

        if targets is None or (isinstance(targets, dict) and name not in targets):
            # No explicit target available: fall back to the model output's
            # constant target, broadcast to the prediction's shape/device.
            if not hasattr(model_out, "target"):
                raise ValueError(f"'{model_out.__class__.__name__}' has no target.")
            if isinstance(predictions, dict):
                pred_col = predictions[name]
            else:
                pred_col = predictions
            _targets = torch.ones_like(pred_col) * model_out.target
        elif isinstance(targets, dict):
            _targets = targets[name]
        elif isinstance(targets, torch.Tensor):
            _targets = targets
        else:
            raise ValueError(f"Unknown 'targets' type: {type(targets)}")

        if isinstance(predictions, dict):
            if name not in predictions:
                raise RuntimeError(f"Column '{name}' not found in predictions")
            _predictions = predictions[name]
        elif isinstance(predictions, torch.Tensor):
            _predictions = predictions
        else:
            raise ValueError(f"Unknown 'predictions' type: {type(predictions)}")

        # Divide by the number of outputs so each head contributes equally
        # to the averaged total loss.
        loss = loss + model_out.loss(_predictions, _targets) / len(model_outputs)

        if not compute_metrics:
            continue

        for metric in model_out.metrics:
            metric_name = camelcase_to_snakecase(metric.__class__.__name__)
            metrics[metric_name] = metric(_predictions, _targets)

    # "loss" is inserted first to preserve the original key ordering.
    return {"loss": loss, **metrics}
0 commit comments