2121from torch import nn
2222from torch ._jit_internal import _copy_to_script_wrapper
2323
24+ from merlin .models .torch .link import Link , LinkType
2425from merlin .models .torch .utils import torchscript_utils
2526
2627
@@ -46,7 +47,7 @@ def __init__(self, *inputs: nn.Module, name: Optional[str] = None):
4647
4748 self ._name : str = name
4849
49- def append (self , module : nn .Module ):
50+ def append (self , module : nn .Module , link : Optional [ Link ] = None ):
5051 """Appends a given module to the end of the list.
5152
5253 Parameters
@@ -58,11 +59,12 @@ def append(self, module: nn.Module):
5859 -------
5960 self
6061 """
61- self .values .append (self .wrap_module (module ))
62+ _module = self ._check_link (module , link = link )
63+ self .values .append (self .wrap_module (_module ))
6264
6365 return self
6466
65- def prepend (self , module : nn .Module ):
67+ def prepend (self , module : nn .Module , link : Optional [ Link ] = None ):
6668 """Prepends a given module to the beginning of the list.
6769
6870 Parameters
@@ -74,9 +76,9 @@ def prepend(self, module: nn.Module):
7476 -------
7577 self
7678 """
77- return self .insert (0 , module )
79+ return self .insert (0 , module , link = link )
7880
79- def insert (self , index : int , module : nn .Module ):
81+ def insert (self , index : int , module : nn .Module , link : Optional [ Link ] = None ):
8082 """Inserts a given module at the specified index.
8183
8284 Parameters
@@ -90,8 +92,8 @@ def insert(self, index: int, module: nn.Module):
9092 -------
9193 self
9294 """
93-
94- self .values .insert (index , self .wrap_module (module ))
95+ _module = self . _check_link ( module , link = link )
96+ self .values .insert (index , self .wrap_module (_module ))
9597
9698 return self
9799
@@ -152,6 +154,15 @@ def __repr__(self) -> str:
152154 def _get_name (self ) -> str :
153155 return super ()._get_name () if self ._name is None else self ._name
154156
157+ def _check_link (self , module : nn .Module , link : Optional [LinkType ] = None ) -> nn .Module :
158+ if link :
159+ linked_module : Link = Link .parse (link )
160+ linked_module .setup_link (module )
161+
162+ return linked_module
163+
164+ return module
165+
155166
156167class BlockContainerDict (nn .ModuleDict ):
157168 def __init__ (
@@ -166,28 +177,36 @@ def __init__(
166177 super ().__init__ (modules )
167178 self ._name : str = name
168179
169- def append_to (self , name : str , module : nn .Module ) -> "BlockContainerDict" :
170- self ._modules [name ].append (module )
180+ def append_to (
181+ self , name : str , module : nn .Module , link : Optional [LinkType ] = None
182+ ) -> "BlockContainerDict" :
183+ self ._modules [name ].append (module , link = link )
171184
172185 return self
173186
174- def prepend_to (self , name : str , module : nn .Module ) -> "BlockContainerDict" :
175- self ._modules [name ].prepend (module )
187+ def prepend_to (
188+ self , name : str , module : nn .Module , link : Optional [LinkType ] = None
189+ ) -> "BlockContainerDict" :
190+ self ._modules [name ].prepend (module , link = link )
176191
177192 return self
178193
179194 # Append to all branches, optionally copying
180- def append_for_each (self , module : nn .Module , shared = False ) -> "BlockContainerDict" :
195+ def append_for_each (
196+ self , module : nn .Module , shared = False , link : Optional [LinkType ] = None
197+ ) -> "BlockContainerDict" :
181198 for branch in self .values ():
182199 _module = module if shared else deepcopy (module )
183- branch .append (_module )
200+ branch .append (_module , link = link )
184201
185202 return self
186203
187- def prepend_for_each (self , module : nn .Module , shared = False ) -> "BlockContainerDict" :
204+ def prepend_for_each (
205+ self , module : nn .Module , shared = False , link : Optional [LinkType ] = None
206+ ) -> "BlockContainerDict" :
188207 for branch in self .values ():
189208 _module = module if shared else deepcopy (module )
190- branch .prepend (_module )
209+ branch .prepend (_module , link = link )
191210
192211 return self
193212
0 commit comments