Empty file.
84 changes: 84 additions & 0 deletions merlin/models/torch/blocks/mlp.py
@@ -0,0 +1,84 @@
from typing import List, Optional, Sequence

from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.transforms.agg import Concat, MaybeAgg


class MLPBlock(Block):
"""
    Multi-Layer Perceptron (MLP) block with configurable activation, normalization,
    and dropout.

Parameters
----------
units : Sequence[int]
Sequence of integers specifying the dimensions of each linear layer.
    activation : Callable or nn.Module, optional
        Activation applied after each linear layer; either a module class or an
        instance. Default is nn.ReLU.
normalization : Union[str, nn.Module], optional
Normalization method to apply after the activation function.
        Supported options are "batchnorm" or any custom `nn.Module`.
Default is None (no normalization).
    dropout : float or nn.Module, optional
        Dropout probability (or an nn.Dropout module) to apply after the normalization.
Default is None (no dropout).
    pre_agg : nn.Module, optional
        Aggregation module applied before the MLP layers when a dictionary of
        tensors is passed as input.
Default is MaybeAgg(Concat()).

Examples
--------
    >>> mlp = MLPBlock([128, 64], activation=nn.ReLU, normalization="batchnorm", dropout=0.5)
>>> input_tensor = torch.randn(32, 100) # batch_size=32, feature_dim=100
>>> output = mlp(input_tensor)
>>> print(output.shape)
torch.Size([32, 64]) # batch_size=32, output_dim=64 (from the last layer of MLP)
    >>> features = {"a": torch.randn(32, 50), "b": torch.randn(32, 50)}  # concatenated to 100 features
    >>> output = mlp(features)
    >>> print(output.shape)
    torch.Size([32, 64])  # batch_size=32, output_dim=64 (from the last layer of MLP)

Raises
------
ValueError
If the normalization parameter is not supported.
"""

def __init__(
self,
units: Sequence[int],
activation=nn.ReLU,
normalization=None,
dropout: Optional[float] = None,
pre_agg: Optional[nn.Module] = MaybeAgg(Concat()),
):
modules: List[nn.Module] = []

if pre_agg is not None:
modules.append(pre_agg)

if not isinstance(units, list):
units = list(units)

for dim in units:
modules.append(nn.LazyLinear(dim))
if activation is not None:
modules.append(activation if isinstance(activation, nn.Module) else activation())

if normalization:
if normalization == "batchnorm":
modules.append(nn.LazyBatchNorm1d())
elif isinstance(normalization, nn.Module):
modules.append(normalization)
else:
raise ValueError(f"Normalization {normalization} not supported")

if dropout:
if isinstance(dropout, nn.Module):
modules.append(dropout)
else:
modules.append(nn.Dropout(dropout))

super().__init__(*modules)
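
For context, a minimal usage sketch of the new block. This is illustrative only; the layer sizes, feature names, and shapes below are assumptions, not part of the change:

# Hypothetical usage of MLPBlock; shapes and dictionary keys are made up.
import torch
from torch import nn

from merlin.models.torch.blocks.mlp import MLPBlock

mlp = MLPBlock([64, 32], activation=nn.ReLU, normalization="batchnorm", dropout=0.1)

# Plain tensor input: nn.LazyLinear infers the 16-dim input on the first call.
dense = mlp(torch.randn(8, 16))
print(dense.shape)  # torch.Size([8, 32])

# Dictionary input: the default MaybeAgg(Concat()) concatenates the values along
# the feature axis (8 + 8 = 16 here, matching the already-initialized first layer).
features = {"user": torch.randn(8, 8), "item": torch.randn(8, 8)}
tabular = mlp(features)
print(tabular.shape)  # torch.Size([8, 32])
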
7 changes: 5 additions & 2 deletions merlin/models/torch/container.py
@@ -119,7 +119,7 @@ def insert(self, index: int, module: nn.Module, link: Optional[Link] = None):
return self

def unwrap(self) -> nn.ModuleList:
-        return nn.ModuleList([m.unwrap() for m in self])
+        return nn.ModuleList(iter(self))

def wrap_module(
self, module: nn.Module
@@ -144,10 +144,13 @@ def __len__(self) -> int:

@_copy_to_script_wrapper
def __iter__(self) -> Iterator[nn.Module]:
-        return iter(self.values)
+        return iter(m.unwrap() for m in self.values)

@_copy_to_script_wrapper
def __getitem__(self, idx: Union[slice, int]):
if isinstance(idx, slice):
return BlockContainer(*[v for v in self.values[idx]])

return self.values[idx].unwrap()

def __setitem__(self, idx: int, module: nn.Module) -> None:
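
The net effect of the container changes is that integer indexing and iteration now yield the underlying modules rather than their wrappers, while slicing still returns a BlockContainer. A rough sketch of the expected behavior, using the new MLPBlock from this PR (which the unit tests index and slice the same way); the exact module positions are assumptions based on how MLPBlock builds its module list:

from torch import nn

from merlin.models.torch.blocks.mlp import MLPBlock

mlp = MLPBlock([32, 16])

# Integer indexing returns the unwrapped module (index 0 is the pre-aggregation).
assert isinstance(mlp[1], nn.LazyLinear)

# Slicing returns a container, and iterating it also yields unwrapped modules.
assert all(isinstance(m, (nn.LazyLinear, nn.ReLU)) for m in mlp[1:])
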
4 changes: 3 additions & 1 deletion merlin/models/torch/utils/module_utils.py
@@ -15,7 +15,7 @@
#

import inspect
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
@@ -54,6 +54,8 @@ def is_tabular(module: torch.nn.Module) -> bool:
# Check if the annotation is a dict of tensors
if first_arg.annotation == Dict[str, torch.Tensor]:
return True
elif first_arg.annotation == Union[torch.Tensor, Dict[str, torch.Tensor]]:
return True

return False

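
One detail worth noting: the new branch compares the annotation against Union[torch.Tensor, Dict[str, torch.Tensor]], and typing treats Union members as unordered for equality, so this single check covers both argument orders exercised in the tests below. A quick standalone check of that assumption:

from typing import Dict, Union

import torch

a = Union[torch.Tensor, Dict[str, torch.Tensor]]
b = Union[Dict[str, torch.Tensor], torch.Tensor]
assert a == b  # Union equality ignores member order, so one comparison suffices
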
Empty file.
70 changes: 70 additions & 0 deletions tests/unit/torch/blocks/test_mlp.py
@@ -0,0 +1,70 @@
import pytest
import torch
from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.blocks.mlp import MLPBlock
from merlin.models.torch.utils import module_utils


class TestMLPBlock:
def test_init(self):
units = (32, 64, 128)
mlp = MLPBlock(units)
assert isinstance(mlp, MLPBlock)
assert isinstance(mlp, Block)
assert len(mlp) == len(units) * 2 + 1

def test_activation(self):
units = [32, 64, 128]
mlp = MLPBlock(units, activation=nn.ReLU)
assert isinstance(mlp, MLPBlock)
for i, module in enumerate(mlp[1:]):
if i % 2 == 1:
assert isinstance(module, nn.ReLU)

def test_normalization_batch_norm(self):
units = [32, 64, 128]
mlp = MLPBlock(units, normalization="batchnorm")
assert isinstance(mlp, MLPBlock)
for i, module in enumerate(mlp[1:]):
if (i + 1) % 3 == 0:
assert isinstance(module, nn.LazyBatchNorm1d)

def test_normalization_custom(self):
units = [32, 64, 128]
custom_norm = nn.LayerNorm(1)
mlp = MLPBlock(units, normalization=custom_norm)
assert isinstance(mlp, MLPBlock)
for i, module in enumerate(mlp[1:]):
if i % 3 == 2:
assert isinstance(module, nn.LayerNorm)

def test_normalization_invalid(self):
units = [32, 64, 128]
with pytest.raises(ValueError):
MLPBlock(units, normalization="invalid")

def test_dropout_float(self):
units = [32, 64, 128]
mlp = MLPBlock(units, dropout=0.5)
assert isinstance(mlp, MLPBlock)
for i, module in enumerate(mlp[1:]):
if i % 3 == 2:
assert isinstance(module, nn.Dropout)
assert module.p == 0.5

def test_dropout_module(self):
units = [32, 64, 128]
mlp = MLPBlock(units, dropout=nn.Dropout(0.5))
assert isinstance(mlp, MLPBlock)
for i, module in enumerate(mlp[1:]):
if i % 3 == 2:
assert isinstance(module, nn.Dropout)
assert module.p == 0.5

def test_forward(self):
mlp = MLPBlock([32])
inputs = {"a": torch.randn(32, 2), "b": torch.randn(32, 2)}
outputs = module_utils.module_test(mlp, inputs)
assert outputs.shape == torch.Size([32, 32])
19 changes: 18 additions & 1 deletion tests/unit/torch/utils/test_module_utils.py
@@ -14,7 +14,7 @@
# limitations under the License.
#

-from typing import Dict
+from typing import Dict, Union

import pytest
import torch
@@ -34,6 +34,16 @@ def forward(self, x: Dict[str, torch.Tensor]):
pass


class ModuleWithDictUnion(nn.Module):
def forward(self, x: Union[Dict[str, torch.Tensor], torch.Tensor]):
pass


class ModuleWithDictUnion2(nn.Module):
def forward(self, x: Union[torch.Tensor, Dict[str, torch.Tensor]]):
pass


class ModuleWithBatch(nn.Module):
def forward(self, x, batch=None):
pass
@@ -57,6 +67,13 @@ def test_basic(self):
assert is_tabular(module_with_dict)
assert not is_tabular(module_without_dict)

def test_union(self):
module_with_dict_union = ModuleWithDictUnion()
module_with_dict_union2 = ModuleWithDictUnion2()

assert is_tabular(module_with_dict_union)
assert is_tabular(module_with_dict_union2)


class Test_check_batch_arg:
def test_basic(self):