2121from torch import nn
2222from torch ._jit_internal import _copy_to_script_wrapper
2323
24+ from merlin .models .torch .link import Link , LinkType
2425from merlin .models .torch .utils import torchscript_utils
2526
2627
@@ -46,37 +47,52 @@ def __init__(self, *inputs: nn.Module, name: Optional[str] = None):
4647
4748 self ._name : str = name
4849
49- def append (self , module : nn .Module ):
50+ def append (self , module : nn .Module , link : Optional [ Link ] = None ):
5051 """Appends a given module to the end of the list.
5152
5253 Parameters
5354 ----------
5455 module : nn.Module
5556 The PyTorch module to be appended.
57+ link : Optional[LinkType]
58+ The link to use for the module. If None, no link is used.
59+ This can either be a Module or a string, options are:
60+ - "residual": Adds a residual connection to the module.
61+ - "shortcut": Adds a shortcut connection to the module.
62+ - "shortcut-concat": Adds a shortcut connection by concatenating
63+ the input and output.
5664
5765 Returns
5866 -------
5967 self
6068 """
61- self .values .append (self .wrap_module (module ))
69+ _module = self ._check_link (module , link = link )
70+ self .values .append (self .wrap_module (_module ))
6271
6372 return self
6473
65- def prepend (self , module : nn .Module ):
74+ def prepend (self , module : nn .Module , link : Optional [ Link ] = None ):
6675 """Prepends a given module to the beginning of the list.
6776
6877 Parameters
6978 ----------
7079 module : nn.Module
7180 The PyTorch module to be prepended.
81+ link : Optional[LinkType]
82+ The link to use for the module. If None, no link is used.
83+ This can either be a Module or a string, options are:
84+ - "residual": Adds a residual connection to the module.
85+ - "shortcut": Adds a shortcut connection to the module.
86+ - "shortcut-concat": Adds a shortcut connection by concatenating
87+ the input and output.
7288
7389 Returns
7490 -------
7591 self
7692 """
77- return self .insert (0 , module )
93+ return self .insert (0 , module , link = link )
7894
79- def insert (self , index : int , module : nn .Module ):
95+ def insert (self , index : int , module : nn .Module , link : Optional [ Link ] = None ):
8096 """Inserts a given module at the specified index.
8197
8298 Parameters
@@ -85,13 +101,20 @@ def insert(self, index: int, module: nn.Module):
85101 The index at which the module is to be inserted.
86102 module : nn.Module
87103 The PyTorch module to be inserted.
104+ link : Optional[LinkType]
105+ The link to use for the module. If None, no link is used.
106+ This can either be a Module or a string, options are:
107+ - "residual": Adds a residual connection to the module.
108+ - "shortcut": Adds a shortcut connection to the module.
109+ - "shortcut-concat": Adds a shortcut connection by concatenating
110+ the input and output.
88111
89112 Returns
90113 -------
91114 self
92115 """
93-
94- self .values .insert (index , self .wrap_module (module ))
116+ _module = self . _check_link ( module , link = link )
117+ self .values .insert (index , self .wrap_module (_module ))
95118
96119 return self
97120
@@ -152,6 +175,15 @@ def __repr__(self) -> str:
152175 def _get_name (self ) -> str :
153176 return super ()._get_name () if self ._name is None else self ._name
154177
178+ def _check_link (self , module : nn .Module , link : Optional [LinkType ] = None ) -> nn .Module :
179+ if link :
180+ linked_module : Link = Link .parse (link )
181+ linked_module .setup_link (module )
182+
183+ return linked_module
184+
185+ return module
186+
155187
156188class BlockContainerDict (nn .ModuleDict ):
157189 """A container class for PyTorch `nn.Module` that allows for manipulation and traversal
@@ -178,7 +210,9 @@ def __init__(
178210 super ().__init__ (modules )
179211 self ._name : str = name
180212
181- def append_to (self , name : str , module : nn .Module ) -> "BlockContainerDict" :
213+ def append_to (
214+ self , name : str , module : nn .Module , link : Optional [LinkType ] = None
215+ ) -> "BlockContainerDict" :
182216 """Appends a module to a specified name.
183217
184218 Parameters
@@ -187,18 +221,27 @@ def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
187221 The name of the branch.
188222 module : nn.Module
189223 The module to append.
224+ link : Optional[LinkType]
225+ The link to use for the module. If None, no link is used.
226+ This can either be a Module or a string, options are:
227+ - "residual": Adds a residual connection to the module.
228+ - "shortcut": Adds a shortcut connection to the module.
229+ - "shortcut-concat": Adds a shortcut connection by concatenating
230+ the input and output.
190231
191232 Returns
192233 -------
193234 BlockContainerDict
194235 The current object itself.
195236 """
196237
197- self ._modules [name ].append (module )
238+ self ._modules [name ].append (module , link = link )
198239
199240 return self
200241
201- def prepend_to (self , name : str , module : nn .Module ) -> "BlockContainerDict" :
242+ def prepend_to (
243+ self , name : str , module : nn .Module , link : Optional [LinkType ] = None
244+ ) -> "BlockContainerDict" :
202245 """Prepends a module to a specified name.
203246
204247 Parameters
@@ -207,19 +250,25 @@ def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
207250 The name of the branch.
208251 module : nn.Module
209252 The module to prepend.
253+ link : Optional[LinkType]
254+ The link to use for the module. If None, no link is used.
255+ This can either be a Module or a string, options are:
256+ - "residual": Adds a residual connection to the module.
257+ - "shortcut": Adds a shortcut connection to the module.
258+ - "shortcut-concat": Adds a shortcut connection by concatenating
259+ the input and output.
210260
211261 Returns
212262 -------
213263 BlockContainerDict
214264 The current object itself.
215265 """
216266
217- self ._modules [name ].prepend (module )
267+ self ._modules [name ].prepend (module , link = link )
218268
219- return self
220-
221- # Append to all branches, optionally copying
222- def append_for_each (self , module : nn .Module , shared = False ) -> "BlockContainerDict" :
269+ def append_for_each (
270+ self , module : nn .Module , shared = False , link : Optional [LinkType ] = None
271+ ) -> "BlockContainerDict" :
223272 """Appends a module to each branch.
224273
225274 Parameters
@@ -229,6 +278,13 @@ def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDic
229278 shared : bool, default=False
230279 If True, the same module is shared across all elements.
231280 Otherwise a deep copy of the module is used in each element.
281+ link : Optional[LinkType]
282+ The link to use for the module. If None, no link is used.
283+ This can either be a Module or a string, options are:
284+ - "residual": Adds a residual connection to the module.
285+ - "shortcut": Adds a shortcut connection to the module.
286+ - "shortcut-concat": Adds a shortcut connection by concatenating
287+ the input and output.
232288
233289 Returns
234290 -------
@@ -238,11 +294,13 @@ def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDic
238294
239295 for branch in self .values ():
240296 _module = module if shared else deepcopy (module )
241- branch .append (_module )
297+ branch .append (_module , link = link )
242298
243299 return self
244300
245- def prepend_for_each (self , module : nn .Module , shared = False ) -> "BlockContainerDict" :
301+ def prepend_for_each (
302+ self , module : nn .Module , shared = False , link : Optional [LinkType ] = None
303+ ) -> "BlockContainerDict" :
246304 """Prepends a module to each branch.
247305
248306 Parameters
@@ -252,23 +310,25 @@ def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDi
252310 shared : bool, default=False
253311 If True, the same module is shared across all elements.
254312 Otherwise a deep copy of the module is used in each element.
313+ link : Optional[LinkType]
314+ The link to use for the module. If None, no link is used.
315+ This can either be a Module or a string, options are:
316+ - "residual": Adds a residual connection to the module.
317+ - "shortcut": Adds a shortcut connection to the module.
318+ - "shortcut-concat": Adds a shortcut connection by concatenating
319+ the input and output.
255320
256321 Returns
257322 -------
258323 BlockContainerDict
259324 The current object itself.
260325 """
261-
262326 for branch in self .values ():
263327 _module = module if shared else deepcopy (module )
264- branch .prepend (_module )
328+ branch .prepend (_module , link = link )
265329
266330 return self
267331
268- # This assumes same branches, we append it's content to each branch
269- # def append_parallel(self, module: "BlockContainerDict") -> "BlockContainerDict":
270- # ...
271-
272332 def add_module (self , name : str , module : Optional [nn .Module ]) -> None :
273333 if module and not isinstance (module , BlockContainer ):
274334 module = BlockContainer (module , name = name [0 ].upper () + name [1 :])
0 commit comments