Skip to content

Commit 9103781

Browse files
marcromeynedknv
andauthored
Adding Link (#1091)
* Adding Link * Fixing doc-strings --------- Co-authored-by: edknv <109497216+edknv@users.noreply.github.com>
1 parent 2dac821 commit 9103781

File tree

7 files changed

+313
-49
lines changed

7 files changed

+313
-49
lines changed

merlin/models/torch/block.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from merlin.models.torch.batch import Batch
2424
from merlin.models.torch.container import BlockContainer, BlockContainerDict
25+
from merlin.models.torch.link import Link, LinkType
2526
from merlin.models.torch.registry import registry
2627
from merlin.models.utils.registry import RegistryMixin
2728

@@ -65,7 +66,7 @@ def forward(
6566

6667
return inputs
6768

68-
def repeat(self, n: int = 1, name=None) -> "Block":
69+
def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Block":
6970
"""
7071
Creates a new block by repeating the current block `n` times.
7172
Each repetition is a deep copy of the current block.
@@ -89,6 +90,9 @@ def repeat(self, n: int = 1, name=None) -> "Block":
8990
raise ValueError("n must be greater than 0")
9091

9192
repeats = [self.copy() for _ in range(n - 1)]
93+
if link:
94+
parsed_link = Link.parse(link)
95+
repeats = [parsed_link.copy().setup_link(repeat) for repeat in repeats]
9296

9397
return Block(self, *repeats, name=name)
9498

@@ -205,7 +209,7 @@ def forward(
205209

206210
return outputs
207211

208-
def append(self, module: nn.Module):
212+
def append(self, module: nn.Module, link: Optional[LinkType] = None):
209213
"""Appends a module to the post-processing stage.
210214
211215
Parameters
@@ -219,29 +223,16 @@ def append(self, module: nn.Module):
219223
The current object itself.
220224
"""
221225

222-
self.post.append(module)
226+
self.post.append(module, link=link)
223227

224228
return self
225229

226230
def prepend(self, module: nn.Module):
227-
"""Prepends a module to the pre-processing stage.
228-
229-
Parameters
230-
----------
231-
module : nn.Module
232-
The module to prepend.
233-
234-
Returns
235-
-------
236-
ParallelBlock
237-
The current object itself.
238-
"""
239-
240231
self.pre.prepend(module)
241232

242233
return self
243234

244-
def append_to(self, name: str, module: nn.Module):
235+
def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
245236
"""Appends a module to a specified branch.
246237
247238
Parameters
@@ -257,11 +248,11 @@ def append_to(self, name: str, module: nn.Module):
257248
The current object itself.
258249
"""
259250

260-
self.branches[name].append(module)
251+
self.branches[name].append(module, link=link)
261252

262253
return self
263254

264-
def prepend_to(self, name: str, module: nn.Module):
255+
def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
265256
"""Prepends a module to a specified branch.
266257
267258
Parameters
@@ -276,11 +267,11 @@ def prepend_to(self, name: str, module: nn.Module):
276267
ParallelBlock
277268
The current object itself.
278269
"""
279-
self.branches[name].prepend(module)
270+
self.branches[name].prepend(module, link=link)
280271

281272
return self
282273

283-
def append_for_each(self, module: nn.Module, shared=False):
274+
def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
284275
"""Appends a module to each branch.
285276
286277
Parameters
@@ -297,11 +288,11 @@ def append_for_each(self, module: nn.Module, shared=False):
297288
The current object itself.
298289
"""
299290

300-
self.branches.append_for_each(module, shared=shared)
291+
self.branches.append_for_each(module, shared=shared, link=link)
301292

302293
return self
303294

304-
def prepend_for_each(self, module: nn.Module, shared=False):
295+
def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
305296
"""Prepends a module to each branch.
306297
307298
Parameters
@@ -318,7 +309,7 @@ def prepend_for_each(self, module: nn.Module, shared=False):
318309
The current object itself.
319310
"""
320311

321-
self.branches.prepend_for_each(module, shared=shared)
312+
self.branches.prepend_for_each(module, shared=shared, link=link)
322313

323314
return self
324315

merlin/models/torch/container.py

Lines changed: 83 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torch import nn
2222
from torch._jit_internal import _copy_to_script_wrapper
2323

24+
from merlin.models.torch.link import Link, LinkType
2425
from merlin.models.torch.utils import torchscript_utils
2526

2627

@@ -46,37 +47,52 @@ def __init__(self, *inputs: nn.Module, name: Optional[str] = None):
4647

4748
self._name: str = name
4849

49-
def append(self, module: nn.Module):
50+
def append(self, module: nn.Module, link: Optional[Link] = None):
5051
"""Appends a given module to the end of the list.
5152
5253
Parameters
5354
----------
5455
module : nn.Module
5556
The PyTorch module to be appended.
57+
link : Optional[LinkType]
58+
The link to use for the module. If None, no link is used.
59+
This can either be a Module or a string, options are:
60+
- "residual": Adds a residual connection to the module.
61+
- "shortcut": Adds a shortcut connection to the module.
62+
- "shortcut-concat": Adds a shortcut connection by concatenating
63+
the input and output.
5664
5765
Returns
5866
-------
5967
self
6068
"""
61-
self.values.append(self.wrap_module(module))
69+
_module = self._check_link(module, link=link)
70+
self.values.append(self.wrap_module(_module))
6271

6372
return self
6473

65-
def prepend(self, module: nn.Module):
74+
def prepend(self, module: nn.Module, link: Optional[Link] = None):
6675
"""Prepends a given module to the beginning of the list.
6776
6877
Parameters
6978
----------
7079
module : nn.Module
7180
The PyTorch module to be prepended.
81+
link : Optional[LinkType]
82+
The link to use for the module. If None, no link is used.
83+
This can either be a Module or a string, options are:
84+
- "residual": Adds a residual connection to the module.
85+
- "shortcut": Adds a shortcut connection to the module.
86+
- "shortcut-concat": Adds a shortcut connection by concatenating
87+
the input and output.
7288
7389
Returns
7490
-------
7591
self
7692
"""
77-
return self.insert(0, module)
93+
return self.insert(0, module, link=link)
7894

79-
def insert(self, index: int, module: nn.Module):
95+
def insert(self, index: int, module: nn.Module, link: Optional[Link] = None):
8096
"""Inserts a given module at the specified index.
8197
8298
Parameters
@@ -85,13 +101,20 @@ def insert(self, index: int, module: nn.Module):
85101
The index at which the module is to be inserted.
86102
module : nn.Module
87103
The PyTorch module to be inserted.
104+
link : Optional[LinkType]
105+
The link to use for the module. If None, no link is used.
106+
This can either be a Module or a string, options are:
107+
- "residual": Adds a residual connection to the module.
108+
- "shortcut": Adds a shortcut connection to the module.
109+
- "shortcut-concat": Adds a shortcut connection by concatenating
110+
the input and output.
88111
89112
Returns
90113
-------
91114
self
92115
"""
93-
94-
self.values.insert(index, self.wrap_module(module))
116+
_module = self._check_link(module, link=link)
117+
self.values.insert(index, self.wrap_module(_module))
95118

96119
return self
97120

@@ -152,6 +175,15 @@ def __repr__(self) -> str:
152175
def _get_name(self) -> str:
153176
return super()._get_name() if self._name is None else self._name
154177

178+
def _check_link(self, module: nn.Module, link: Optional[LinkType] = None) -> nn.Module:
179+
if link:
180+
linked_module: Link = Link.parse(link)
181+
linked_module.setup_link(module)
182+
183+
return linked_module
184+
185+
return module
186+
155187

156188
class BlockContainerDict(nn.ModuleDict):
157189
"""A container class for PyTorch `nn.Module` that allows for manipulation and traversal
@@ -178,7 +210,9 @@ def __init__(
178210
super().__init__(modules)
179211
self._name: str = name
180212

181-
def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
213+
def append_to(
214+
self, name: str, module: nn.Module, link: Optional[LinkType] = None
215+
) -> "BlockContainerDict":
182216
"""Appends a module to a specified name.
183217
184218
Parameters
@@ -187,18 +221,27 @@ def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
187221
The name of the branch.
188222
module : nn.Module
189223
The module to append.
224+
link : Optional[LinkType]
225+
The link to use for the module. If None, no link is used.
226+
This can either be a Module or a string, options are:
227+
- "residual": Adds a residual connection to the module.
228+
- "shortcut": Adds a shortcut connection to the module.
229+
- "shortcut-concat": Adds a shortcut connection by concatenating
230+
the input and output.
190231
191232
Returns
192233
-------
193234
BlockContainerDict
194235
The current object itself.
195236
"""
196237

197-
self._modules[name].append(module)
238+
self._modules[name].append(module, link=link)
198239

199240
return self
200241

201-
def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
242+
def prepend_to(
243+
self, name: str, module: nn.Module, link: Optional[LinkType] = None
244+
) -> "BlockContainerDict":
202245
"""Prepends a module to a specified name.
203246
204247
Parameters
@@ -207,19 +250,25 @@ def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
207250
The name of the branch.
208251
module : nn.Module
209252
The module to prepend.
253+
link : Optional[LinkType]
254+
The link to use for the module. If None, no link is used.
255+
This can either be a Module or a string, options are:
256+
- "residual": Adds a residual connection to the module.
257+
- "shortcut": Adds a shortcut connection to the module.
258+
- "shortcut-concat": Adds a shortcut connection by concatenating
259+
the input and output.
210260
211261
Returns
212262
-------
213263
BlockContainerDict
214264
The current object itself.
215265
"""
216266

217-
self._modules[name].prepend(module)
267+
self._modules[name].prepend(module, link=link)
218268

219-
return self
220-
221-
# Append to all branches, optionally copying
222-
def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
269+
def append_for_each(
270+
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
271+
) -> "BlockContainerDict":
223272
"""Appends a module to each branch.
224273
225274
Parameters
@@ -229,6 +278,13 @@ def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDic
229278
shared : bool, default=False
230279
If True, the same module is shared across all elements.
231280
Otherwise a deep copy of the module is used in each element.
281+
link : Optional[LinkType]
282+
The link to use for the module. If None, no link is used.
283+
This can either be a Module or a string, options are:
284+
- "residual": Adds a residual connection to the module.
285+
- "shortcut": Adds a shortcut connection to the module.
286+
- "shortcut-concat": Adds a shortcut connection by concatenating
287+
the input and output.
232288
233289
Returns
234290
-------
@@ -238,11 +294,13 @@ def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDic
238294

239295
for branch in self.values():
240296
_module = module if shared else deepcopy(module)
241-
branch.append(_module)
297+
branch.append(_module, link=link)
242298

243299
return self
244300

245-
def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
301+
def prepend_for_each(
302+
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
303+
) -> "BlockContainerDict":
246304
"""Prepends a module to each branch.
247305
248306
Parameters
@@ -252,23 +310,25 @@ def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDi
252310
shared : bool, default=False
253311
If True, the same module is shared across all elements.
254312
Otherwise a deep copy of the module is used in each element.
313+
link : Optional[LinkType]
314+
The link to use for the module. If None, no link is used.
315+
This can either be a Module or a string, options are:
316+
- "residual": Adds a residual connection to the module.
317+
- "shortcut": Adds a shortcut connection to the module.
318+
- "shortcut-concat": Adds a shortcut connection by concatenating
319+
the input and output.
255320
256321
Returns
257322
-------
258323
BlockContainerDict
259324
The current object itself.
260325
"""
261-
262326
for branch in self.values():
263327
_module = module if shared else deepcopy(module)
264-
branch.prepend(_module)
328+
branch.prepend(_module, link=link)
265329

266330
return self
267331

268-
# This assumes same branches, we append it's content to each branch
269-
# def append_parallel(self, module: "BlockContainerDict") -> "BlockContainerDict":
270-
# ...
271-
272332
def add_module(self, name: str, module: Optional[nn.Module]) -> None:
273333
if module and not isinstance(module, BlockContainer):
274334
module = BlockContainer(module, name=name[0].upper() + name[1:])

0 commit comments

Comments
 (0)