Skip to content

Commit 0d418f8

Browse files
committed
Adding Link
1 parent 9696dd7 commit 0d418f8

File tree

7 files changed

+266
-30
lines changed

7 files changed

+266
-30
lines changed

merlin/models/torch/block.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from merlin.models.torch.batch import Batch
2424
from merlin.models.torch.container import BlockContainer, BlockContainerDict
25+
from merlin.models.torch.link import Link, LinkType
2526
from merlin.models.torch.registry import registry
2627
from merlin.models.utils.registry import RegistryMixin
2728

@@ -65,7 +66,7 @@ def forward(
6566

6667
return inputs
6768

68-
def repeat(self, n: int = 1, name=None) -> "Block":
69+
def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Block":
6970
"""
7071
Creates a new block by repeating the current block `n` times.
7172
Each repetition is a deep copy of the current block.
@@ -89,6 +90,9 @@ def repeat(self, n: int = 1, name=None) -> "Block":
8990
raise ValueError("n must be greater than 0")
9091

9192
repeats = [self.copy() for _ in range(n - 1)]
93+
if link:
94+
parsed_link = Link.parse(link)
95+
repeats = [parsed_link.copy().setup_link(repeat) for repeat in repeats]
9296

9397
return Block(self, *repeats, name=name)
9498

@@ -152,33 +156,33 @@ def forward(
152156

153157
return outputs
154158

155-
def append(self, module: nn.Module):
156-
self.post.append(module)
159+
def append(self, module: nn.Module, link: Optional[LinkType] = None):
160+
self.post.append(module, link=link)
157161

158162
return self
159163

160-
def prepend(self, module: nn.Module):
161-
self.pre.prepend(module)
164+
def prepend(self, module: nn.Module, link: Optional[LinkType] = None):
165+
self.pre.prepend(module, link=link)
162166

163167
return self
164168

165-
def append_to(self, name: str, module: nn.Module):
166-
self.branches[name].append(module)
169+
def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
170+
self.branches[name].append(module, link=link)
167171

168172
return self
169173

170-
def prepend_to(self, name: str, module: nn.Module):
171-
self.branches[name].prepend(module)
174+
def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
175+
self.branches[name].prepend(module, link=link)
172176

173177
return self
174178

175-
def append_for_each(self, module: nn.Module, shared=False):
176-
self.branches.append_for_each(module, shared=shared)
179+
def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
180+
self.branches.append_for_each(module, shared=shared, link=link)
177181

178182
return self
179183

180-
def prepend_for_each(self, module: nn.Module, shared=False):
181-
self.branches.prepend_for_each(module, shared=shared)
184+
def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
185+
self.branches.prepend_for_each(module, shared=shared, link=link)
182186

183187
return self
184188

merlin/models/torch/container.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torch import nn
2222
from torch._jit_internal import _copy_to_script_wrapper
2323

24+
from merlin.models.torch.link import Link, LinkType
2425
from merlin.models.torch.utils import torchscript_utils
2526

2627

@@ -46,7 +47,7 @@ def __init__(self, *inputs: nn.Module, name: Optional[str] = None):
4647

4748
self._name: str = name
4849

49-
def append(self, module: nn.Module):
50+
def append(self, module: nn.Module, link: Optional[Link] = None):
5051
"""Appends a given module to the end of the list.
5152
5253
Parameters
@@ -58,11 +59,12 @@ def append(self, module: nn.Module):
5859
-------
5960
self
6061
"""
61-
self.values.append(self.wrap_module(module))
62+
_module = self._check_link(module, link=link)
63+
self.values.append(self.wrap_module(_module))
6264

6365
return self
6466

65-
def prepend(self, module: nn.Module):
67+
def prepend(self, module: nn.Module, link: Optional[Link] = None):
6668
"""Prepends a given module to the beginning of the list.
6769
6870
Parameters
@@ -74,9 +76,9 @@ def prepend(self, module: nn.Module):
7476
-------
7577
self
7678
"""
77-
return self.insert(0, module)
79+
return self.insert(0, module, link=link)
7880

79-
def insert(self, index: int, module: nn.Module):
81+
def insert(self, index: int, module: nn.Module, link: Optional[Link] = None):
8082
"""Inserts a given module at the specified index.
8183
8284
Parameters
@@ -90,8 +92,8 @@ def insert(self, index: int, module: nn.Module):
9092
-------
9193
self
9294
"""
93-
94-
self.values.insert(index, self.wrap_module(module))
95+
_module = self._check_link(module, link=link)
96+
self.values.insert(index, self.wrap_module(_module))
9597

9698
return self
9799

@@ -152,6 +154,15 @@ def __repr__(self) -> str:
152154
def _get_name(self) -> str:
153155
return super()._get_name() if self._name is None else self._name
154156

157+
def _check_link(self, module: nn.Module, link: Optional[LinkType] = None) -> nn.Module:
158+
if link:
159+
linked_module: Link = Link.parse(link)
160+
linked_module.setup_link(module)
161+
162+
return linked_module
163+
164+
return module
165+
155166

156167
class BlockContainerDict(nn.ModuleDict):
157168
def __init__(
@@ -166,28 +177,36 @@ def __init__(
166177
super().__init__(modules)
167178
self._name: str = name
168179

169-
def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
170-
self._modules[name].append(module)
180+
def append_to(
181+
self, name: str, module: nn.Module, link: Optional[LinkType] = None
182+
) -> "BlockContainerDict":
183+
self._modules[name].append(module, link=link)
171184

172185
return self
173186

174-
def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
175-
self._modules[name].prepend(module)
187+
def prepend_to(
188+
self, name: str, module: nn.Module, link: Optional[LinkType] = None
189+
) -> "BlockContainerDict":
190+
self._modules[name].prepend(module, link=link)
176191

177192
return self
178193

179194
# Append to all branches, optionally copying
180-
def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
195+
def append_for_each(
196+
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
197+
) -> "BlockContainerDict":
181198
for branch in self.values():
182199
_module = module if shared else deepcopy(module)
183-
branch.append(_module)
200+
branch.append(_module, link=link)
184201

185202
return self
186203

187-
def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
204+
def prepend_for_each(
205+
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
206+
) -> "BlockContainerDict":
188207
for branch in self.values():
189208
_module = module if shared else deepcopy(module)
190-
branch.prepend(_module)
209+
branch.prepend(_module, link=link)
191210

192211
return self
193212

merlin/models/torch/link.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import copy
2+
from typing import Dict, Optional, Union
3+
4+
import torch
5+
from torch import nn
6+
7+
from merlin.models.torch.registry import TorchRegistryMixin
8+
9+
LinkType = Union[str, "Link"]
10+
11+
12+
class Link(nn.Module, TorchRegistryMixin):
13+
"""Base class for different types of network links.
14+
15+
This is typically used as part of a `Block` to connect different modules.
16+
17+
Some examples of links are:
18+
- `residual`: Adds the input to the output of the module.
19+
- `shortcut`: Outputs a dictionary with the output of the module and the input.
20+
- `shortcut-concat`: Concatenates the input and the output of the module.
21+
22+
"""
23+
24+
def __init__(self, output: Optional[nn.Module] = None):
25+
super().__init__()
26+
27+
if output is not None:
28+
self.setup_link(output)
29+
30+
def setup_link(self, output: nn.Module) -> "Link":
31+
"""
32+
Setup function for the link.
33+
34+
Parameters
35+
----------
36+
output : nn.Module
37+
The output module for the link.
38+
39+
Returns
40+
-------
41+
Link
42+
The updated Link instance.
43+
"""
44+
45+
self.output = output
46+
47+
return self
48+
49+
def copy(self) -> "Link":
50+
"""
51+
Returns a copy of the link.
52+
53+
Returns
54+
-------
55+
Link
56+
The copied link.
57+
"""
58+
return copy.deepcopy(self)
59+
60+
61+
@Link.registry.register("residual")
62+
class Residual(Link):
63+
"""Adds the input to the output of the module."""
64+
65+
def forward(self, x: torch.Tensor) -> torch.Tensor:
66+
return x + self.output(x)
67+
68+
69+
@Link.registry.register("shortcut")
70+
class Shortcut(Link):
71+
"""Outputs a dictionary with the output of the module and the input."""
72+
73+
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
74+
return {"output": self.output(x), "shortcut": x}
75+
76+
77+
@Link.registry.register("shortcut-concat")
78+
class ShortcutConcat(Link):
79+
"""Concatenates the input and the output of the module."""
80+
81+
def forward(self, x: torch.Tensor) -> torch.Tensor:
82+
intermediate_output = self.output(x)
83+
return torch.cat((x, intermediate_output), dim=1)

merlin/models/torch/registry.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1-
from merlin.models.utils.registry import Registry
1+
from merlin.models.utils.registry import Registry, RegistryMixin
22

33
registry: Registry = Registry.class_registry("modules")
4+
5+
6+
class TorchRegistryMixin(RegistryMixin):
7+
registry = registry
8+
9+
10+
__all__ = ["registry", "Registry", "RegistryMixin", "TorchRegistryMixin"]

tests/unit/torch/test_block.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
from torch import nn
2121

22+
from merlin.models.torch import link
2223
from merlin.models.torch.batch import Batch
2324
from merlin.models.torch.block import Block, ParallelBlock
2425
from merlin.models.torch.container import BlockContainer, BlockContainerDict
@@ -59,6 +60,9 @@ def test_insertion(self):
5960

6061
assert torch.equal(outputs, inputs + 2)
6162

63+
block.append(PlusOne(), link="residual")
64+
assert isinstance(block[-1], link.Residual)
65+
6266
def test_copy(self):
6367
block = Block(PlusOne())
6468

@@ -82,6 +86,19 @@ def test_repeat(self):
8286
with pytest.raises(ValueError, match="n must be greater than 0"):
8387
block.repeat(0)
8488

89+
def test_repeat_with_link(self):
90+
block = Block(PlusOne())
91+
92+
repeated = block.repeat(2, link="residual")
93+
assert isinstance(repeated, Block)
94+
assert len(repeated) == 2
95+
assert isinstance(repeated[-1], link.Residual)
96+
97+
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
98+
outputs = module_utils.module_test(repeated, inputs)
99+
100+
assert torch.equal(outputs, (inputs + 1) + (inputs + 1) + 1)
101+
85102
def test_from_registry(self):
86103
@Block.registry.register("my_block")
87104
class MyBlock(Block):

0 commit comments

Comments
 (0)