Skip to content

Commit 1997686

Browse files
marcromeynedknv
andauthored
Adding MLPBlock (#1093)
* Increase test-coverage * Adding ParallelBlock * Adding MLPBlock * Quick-fix activation in MLPBlock * Fixing merge conflicts --------- Co-authored-by: edknv <109497216+edknv@users.noreply.github.com>
1 parent b6d6645 commit 1997686

File tree

7 files changed

+180
-4
lines changed

7 files changed

+180
-4
lines changed

merlin/models/torch/blocks/__init__.py

Whitespace-only changes.

merlin/models/torch/blocks/mlp.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from typing import List, Optional, Sequence
2+
3+
from torch import nn
4+
5+
from merlin.models.torch.block import Block
6+
from merlin.models.torch.transforms.agg import Concat, MaybeAgg
7+
8+
9+
class MLPBlock(Block):
    """
    Multi-Layer Perceptron (MLP) Block with custom options for activation, normalization,
    dropout.

    Parameters
    ----------
    units : Sequence[int]
        Sequence of integers specifying the dimensions of each linear layer.
    activation : Callable, optional
        Activation function to apply after each linear layer. Default is ReLU.
    normalization : Union[str, nn.Module], optional
        Normalization method to apply after the activation function.
        Supported options are "batchnorm" (the spelling "batch_norm" is also
        accepted) or any custom `nn.Module`.
        Default is None (no normalization).
    dropout : Optional[float], optional
        Dropout probability to apply after the normalization.
        Default is None (no dropout).
    pre_agg: nn.Module, optional
        Aggregation applied before the MLP layers when a dictionary is passed
        as input. Defaults to a fresh ``MaybeAgg(Concat())`` built per
        instance; pass ``None`` to disable pre-aggregation.

    Examples
    --------
    >>> mlp = MLPBlock([128, 64], activation=nn.ReLU, normalization="batchnorm", dropout=0.5)
    >>> input_tensor = torch.randn(32, 100)  # batch_size=32, feature_dim=100
    >>> output = mlp(input_tensor)
    >>> print(output.shape)
    torch.Size([32, 64])  # batch_size=32, output_dim=64 (from the last layer of MLP)
    >>> features = {"a": torch.randn(32, 100), "b": torch.randn(32, 100)}
    >>> output = mlp(features)
    >>> print(output.shape)
    torch.Size([32, 64])  # batch_size=32, output_dim=64 (from the last layer of MLP)

    Raises
    ------
    ValueError
        If the normalization parameter is not supported.
    """

    # Sentinel default for ``pre_agg``: an ``nn.Module`` instance used as a
    # default argument would be evaluated once and shared by every MLPBlock
    # (and registered as a child of multiple parents). The sentinel lets us
    # build a fresh MaybeAgg(Concat()) per instance while still allowing
    # callers to pass ``None`` to disable pre-aggregation.
    _DEFAULT_PRE_AGG = object()

    def __init__(
        self,
        units: Sequence[int],
        activation=nn.ReLU,
        normalization=None,
        dropout: Optional[float] = None,
        pre_agg: Optional[nn.Module] = _DEFAULT_PRE_AGG,
    ):
        modules: List[nn.Module] = []

        if pre_agg is MLPBlock._DEFAULT_PRE_AGG:
            pre_agg = MaybeAgg(Concat())
        if pre_agg is not None:
            modules.append(pre_agg)

        if not isinstance(units, list):
            units = list(units)

        for dim in units:
            modules.append(nn.LazyLinear(dim))
            if activation is not None:
                # Accept either an activation class (instantiated here) or an
                # already-constructed module instance.
                modules.append(activation if isinstance(activation, nn.Module) else activation())

            if normalization:
                # Accept both spellings so the documented example
                # (normalization="batch_norm") works as well as "batchnorm".
                if normalization in ("batchnorm", "batch_norm"):
                    modules.append(nn.LazyBatchNorm1d())
                elif isinstance(normalization, nn.Module):
                    # NOTE: a custom module passed here is appended once per
                    # layer, i.e. the same instance is shared across layers.
                    modules.append(normalization)
                else:
                    raise ValueError(f"Normalization {normalization} not supported")

            if dropout:
                if isinstance(dropout, nn.Module):
                    modules.append(dropout)
                else:
                    modules.append(nn.Dropout(dropout))

        super().__init__(*modules)

merlin/models/torch/container.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def insert(self, index: int, module: nn.Module, link: Optional[Link] = None):
119119
return self
120120

121121
def unwrap(self) -> nn.ModuleList:
122-
return nn.ModuleList([m.unwrap() for m in self])
122+
return nn.ModuleList(iter(self))
123123

124124
def wrap_module(
125125
self, module: nn.Module
@@ -144,10 +144,13 @@ def __len__(self) -> int:
144144

145145
@_copy_to_script_wrapper
146146
def __iter__(self) -> Iterator[nn.Module]:
147-
return iter(self.values)
147+
return iter(m.unwrap() for m in self.values)
148148

149149
@_copy_to_script_wrapper
150150
def __getitem__(self, idx: Union[slice, int]):
151+
if isinstance(idx, slice):
152+
return BlockContainer(*[v for v in self.values[idx]])
153+
151154
return self.values[idx].unwrap()
152155

153156
def __setitem__(self, idx: int, module: nn.Module) -> None:

merlin/models/torch/utils/module_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#
1616

1717
import inspect
18-
from typing import Dict, Tuple
18+
from typing import Dict, Tuple, Union
1919

2020
import torch
2121
import torch.nn as nn
@@ -54,6 +54,8 @@ def is_tabular(module: torch.nn.Module) -> bool:
5454
# Check if the annotation is a dict of tensors
5555
if first_arg.annotation == Dict[str, torch.Tensor]:
5656
return True
57+
elif first_arg.annotation == Union[torch.Tensor, Dict[str, torch.Tensor]]:
58+
return True
5759

5860
return False
5961

tests/unit/torch/blocks/__init__.py

Whitespace-only changes.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
import torch
3+
from torch import nn
4+
5+
from merlin.models.torch.block import Block
6+
from merlin.models.torch.blocks.mlp import MLPBlock
7+
from merlin.models.torch.utils import module_utils
8+
9+
10+
class TestMLPBlock:
    """Unit tests for MLPBlock construction, layer layout, and forward pass."""

    def test_init(self):
        dims = (32, 64, 128)
        block = MLPBlock(dims)
        assert isinstance(block, MLPBlock)
        assert isinstance(block, Block)
        # one pre-aggregation module plus a (linear, activation) pair per dim
        assert len(block) == 2 * len(dims) + 1

    def test_activation(self):
        block = MLPBlock([32, 64, 128], activation=nn.ReLU)
        assert isinstance(block, MLPBlock)
        # after the pre-agg module, layout is (linear, relu) repeated
        for child in list(block[1:])[1::2]:
            assert isinstance(child, nn.ReLU)

    def test_normalization_batch_norm(self):
        block = MLPBlock([32, 64, 128], normalization="batchnorm")
        assert isinstance(block, MLPBlock)
        # after the pre-agg module, layout is (linear, relu, batchnorm) repeated
        for child in list(block[1:])[2::3]:
            assert isinstance(child, nn.LazyBatchNorm1d)

    def test_normalization_custom(self):
        layer_norm = nn.LayerNorm(1)
        block = MLPBlock([32, 64, 128], normalization=layer_norm)
        assert isinstance(block, MLPBlock)
        for child in list(block[1:])[2::3]:
            assert isinstance(child, nn.LayerNorm)

    def test_normalization_invalid(self):
        with pytest.raises(ValueError):
            MLPBlock([32, 64, 128], normalization="invalid")

    def test_dropout_float(self):
        block = MLPBlock([32, 64, 128], dropout=0.5)
        assert isinstance(block, MLPBlock)
        # every third module after the pre-agg is the dropout layer
        for child in list(block[1:])[2::3]:
            assert isinstance(child, nn.Dropout)
            assert child.p == 0.5

    def test_dropout_module(self):
        block = MLPBlock([32, 64, 128], dropout=nn.Dropout(0.5))
        assert isinstance(block, MLPBlock)
        for child in list(block[1:])[2::3]:
            assert isinstance(child, nn.Dropout)
            assert child.p == 0.5

    def test_forward(self):
        block = MLPBlock([32])
        inputs = {"a": torch.randn(32, 2), "b": torch.randn(32, 2)}
        outputs = module_utils.module_test(block, inputs)
        assert outputs.shape == torch.Size([32, 32])

tests/unit/torch/utils/test_module_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17-
from typing import Dict
17+
from typing import Dict, Union
1818

1919
import pytest
2020
import torch
@@ -34,6 +34,16 @@ def forward(self, x: Dict[str, torch.Tensor]):
3434
pass
3535

3636

37+
class ModuleWithDictUnion(nn.Module):
38+
def forward(self, x: Union[Dict[str, torch.Tensor], torch.Tensor]):
39+
pass
40+
41+
42+
class ModuleWithDictUnion2(nn.Module):
43+
def forward(self, x: Union[torch.Tensor, Dict[str, torch.Tensor]]):
44+
pass
45+
46+
3747
class ModuleWithBatch(nn.Module):
3848
def forward(self, x, batch=None):
3949
pass
@@ -57,6 +67,13 @@ def test_basic(self):
5767
assert is_tabular(module_with_dict)
5868
assert not is_tabular(module_without_dict)
5969

70+
def test_union(self):
71+
module_with_dict_union = ModuleWithDictUnion()
72+
module_with_dict_union2 = ModuleWithDictUnion2()
73+
74+
assert is_tabular(module_with_dict_union)
75+
assert is_tabular(module_with_dict_union2)
76+
6077

6178
class Test_check_batch_arg:
6279
def test_basic(self):

0 commit comments

Comments
 (0)