@@ -54,6 +54,13 @@ def append(self, module: nn.Module, link: Optional[Link] = None):
5454 ----------
5555 module : nn.Module
5656 The PyTorch module to be appended.
57+ link : Optional[LinkType]
58+ The link to use for the module. If None, no link is used.
59+ This can either be a Module or a string, options are:
60+ - "residual": Adds a residual connection to the module.
61+ - "shortcut": Adds a shortcut connection to the module.
62+ - "shortcut-concat": Adds a shortcut connection by concatenating
63+ the input and output.
5764
5865 Returns
5966 -------
@@ -71,6 +78,13 @@ def prepend(self, module: nn.Module, link: Optional[Link] = None):
7178 ----------
7279 module : nn.Module
7380 The PyTorch module to be prepended.
81+ link : Optional[LinkType]
82+ The link to use for the module. If None, no link is used.
83+ This can either be a Module or a string, options are:
84+ - "residual": Adds a residual connection to the module.
85+ - "shortcut": Adds a shortcut connection to the module.
86+ - "shortcut-concat": Adds a shortcut connection by concatenating
87+ the input and output.
7488
7589 Returns
7690 -------
@@ -87,6 +101,13 @@ def insert(self, index: int, module: nn.Module, link: Optional[Link] = None):
87101 The index at which the module is to be inserted.
88102 module : nn.Module
89103 The PyTorch module to be inserted.
104+ link : Optional[LinkType]
105+ The link to use for the module. If None, no link is used.
106+ This can either be a Module or a string, options are:
107+ - "residual": Adds a residual connection to the module.
108+ - "shortcut": Adds a shortcut connection to the module.
109+ - "shortcut-concat": Adds a shortcut connection by concatenating
110+ the input and output.
90111
91112 Returns
92113 -------
@@ -189,7 +210,9 @@ def __init__(
189210 super ().__init__ (modules )
190211 self ._name : str = name
191212
192- def append_to (self , name : str , module : nn .Module ) -> "BlockContainerDict" :
213+ def append_to (
214+ self , name : str , module : nn .Module , link : Optional [LinkType ] = None
215+ ) -> "BlockContainerDict" :
193216 """Appends a module to a specified name.
194217
195218 Parameters
@@ -198,18 +221,27 @@ def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
198221 The name of the branch.
199222 module : nn.Module
200223 The module to append.
224+ link : Optional[LinkType]
225+ The link to use for the module. If None, no link is used.
226+ This can either be a Module or a string, options are:
227+ - "residual": Adds a residual connection to the module.
228+ - "shortcut": Adds a shortcut connection to the module.
229+ - "shortcut-concat": Adds a shortcut connection by concatenating
230+ the input and output.
201231
202232 Returns
203233 -------
204234 BlockContainerDict
205235 The current object itself.
206236 """
207237
208- self ._modules [name ].append (module )
238+ self ._modules [name ].append (module , link = link )
209239
210240 return self
211241
212- def prepend_to (self , name : str , module : nn .Module ) -> "BlockContainerDict" :
242+ def prepend_to (
243+ self , name : str , module : nn .Module , link : Optional [LinkType ] = None
244+ ) -> "BlockContainerDict" :
213245 """Prepends a module to a specified name.
214246
215247 Parameters
@@ -218,30 +250,25 @@ def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
218250 The name of the branch.
219251 module : nn.Module
220252 The module to prepend.
253+ link : Optional[LinkType]
254+ The link to use for the module. If None, no link is used.
255+ This can either be a Module or a string, options are:
256+ - "residual": Adds a residual connection to the module.
257+ - "shortcut": Adds a shortcut connection to the module.
258+ - "shortcut-concat": Adds a shortcut connection by concatenating
259+ the input and output.
221260
222261 Returns
223262 -------
224263 BlockContainerDict
225264 The current object itself.
226265 """
227266
228- self ._modules [name ].prepend (module )
229- def append_to (
230- self , name : str , module : nn .Module , link : Optional [LinkType ] = None
231- ) -> "BlockContainerDict" :
232- self ._modules [name ].append (module , link = link )
233-
234- return self
235-
236- def prepend_to (
237- self , name : str , module : nn .Module , link : Optional [LinkType ] = None
238- ) -> "BlockContainerDict" :
239267 self ._modules [name ].prepend (module , link = link )
240268
241- return self
242-
243- # Append to all branches, optionally copying
244- def append_for_each (self , module : nn .Module , shared = False ) -> "BlockContainerDict" :
269+ def append_for_each (
270+ self , module : nn .Module , shared = False , link : Optional [LinkType ] = None
271+ ) -> "BlockContainerDict" :
245272 """Appends a module to each branch.
246273
247274 Parameters
@@ -251,23 +278,29 @@ def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDic
251278 shared : bool, default=False
252279 If True, the same module is shared across all elements.
253280 Otherwise a deep copy of the module is used in each element.
281+ link : Optional[LinkType]
282+ The link to use for the module. If None, no link is used.
283+ This can either be a Module or a string, options are:
284+ - "residual": Adds a residual connection to the module.
285+ - "shortcut": Adds a shortcut connection to the module.
286+ - "shortcut-concat": Adds a shortcut connection by concatenating
287+ the input and output.
254288
255289 Returns
256290 -------
257291 BlockContainerDict
258292 The current object itself.
259293 """
260294
261- def append_for_each (
262- self , module : nn .Module , shared = False , link : Optional [LinkType ] = None
263- ) -> "BlockContainerDict" :
264295 for branch in self .values ():
265296 _module = module if shared else deepcopy (module )
266297 branch .append (_module , link = link )
267298
268299 return self
269300
270- def prepend_for_each (self , module : nn .Module , shared = False ) -> "BlockContainerDict" :
301+ def prepend_for_each (
302+ self , module : nn .Module , shared = False , link : Optional [LinkType ] = None
303+ ) -> "BlockContainerDict" :
271304 """Prepends a module to each branch.
272305
273306 Parameters
@@ -277,26 +310,25 @@ def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDi
277310 shared : bool, default=False
278311 If True, the same module is shared across all elements.
279312 Otherwise a deep copy of the module is used in each element.
313+ link : Optional[LinkType]
314+ The link to use for the module. If None, no link is used.
315+ This can either be a Module or a string, options are:
316+ - "residual": Adds a residual connection to the module.
317+ - "shortcut": Adds a shortcut connection to the module.
318+ - "shortcut-concat": Adds a shortcut connection by concatenating
319+ the input and output.
280320
281321 Returns
282322 -------
283323 BlockContainerDict
284324 The current object itself.
285325 """
286-
287- def prepend_for_each (
288- self , module : nn .Module , shared = False , link : Optional [LinkType ] = None
289- ) -> "BlockContainerDict" :
290326 for branch in self .values ():
291327 _module = module if shared else deepcopy (module )
292328 branch .prepend (_module , link = link )
293329
294330 return self
295331
296- # This assumes same branches, we append it's content to each branch
297- # def append_parallel(self, module: "BlockContainerDict") -> "BlockContainerDict":
298- # ...
299-
300332 def add_module (self , name : str , module : Optional [nn .Module ]) -> None :
301333 if module and not isinstance (module , BlockContainer ):
302334 module = BlockContainer (module , name = name [0 ].upper () + name [1 :])
0 commit comments