Skip to content

Commit c026851

Browse files
committed
Fixing doc-strings
1 parent 2542d04 commit c026851

File tree

1 file changed

+62
-30
lines changed

1 file changed

+62
-30
lines changed

merlin/models/torch/container.py

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ def append(self, module: nn.Module, link: Optional[Link] = None):
5454
----------
5555
module : nn.Module
5656
The PyTorch module to be appended.
57+
link : Optional[LinkType]
58+
The link to use for the module. If None, no link is used.
59+
This can either be a Module or a string, options are:
60+
- "residual": Adds a residual connection to the module.
61+
- "shortcut": Adds a shortcut connection to the module.
62+
- "shortcut-concat": Adds a shortcut connection by concatenating
63+
the input and output.
5764
5865
Returns
5966
-------
@@ -71,6 +78,13 @@ def prepend(self, module: nn.Module, link: Optional[Link] = None):
7178
----------
7279
module : nn.Module
7380
The PyTorch module to be prepended.
81+
link : Optional[LinkType]
82+
The link to use for the module. If None, no link is used.
83+
This can either be a Module or a string, options are:
84+
- "residual": Adds a residual connection to the module.
85+
- "shortcut": Adds a shortcut connection to the module.
86+
- "shortcut-concat": Adds a shortcut connection by concatenating
87+
the input and output.
7488
7589
Returns
7690
-------
@@ -87,6 +101,13 @@ def insert(self, index: int, module: nn.Module, link: Optional[Link] = None):
87101
The index at which the module is to be inserted.
88102
module : nn.Module
89103
The PyTorch module to be inserted.
104+
link : Optional[LinkType]
105+
The link to use for the module. If None, no link is used.
106+
This can either be a Module or a string, options are:
107+
- "residual": Adds a residual connection to the module.
108+
- "shortcut": Adds a shortcut connection to the module.
109+
- "shortcut-concat": Adds a shortcut connection by concatenating
110+
the input and output.
90111
91112
Returns
92113
-------
@@ -189,7 +210,9 @@ def __init__(
189210
super().__init__(modules)
190211
self._name: str = name
191212

192-
def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
213+
def append_to(
214+
self, name: str, module: nn.Module, link: Optional[LinkType] = None
215+
) -> "BlockContainerDict":
193216
"""Appends a module to a specified name.
194217
195218
Parameters
@@ -198,18 +221,27 @@ def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
198221
The name of the branch.
199222
module : nn.Module
200223
The module to append.
224+
link : Optional[LinkType]
225+
The link to use for the module. If None, no link is used.
226+
This can either be a Module or a string, options are:
227+
- "residual": Adds a residual connection to the module.
228+
- "shortcut": Adds a shortcut connection to the module.
229+
- "shortcut-concat": Adds a shortcut connection by concatenating
230+
the input and output.
201231
202232
Returns
203233
-------
204234
BlockContainerDict
205235
The current object itself.
206236
"""
207237

208-
self._modules[name].append(module)
238+
self._modules[name].append(module, link=link)
209239

210240
return self
211241

212-
def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
242+
def prepend_to(
243+
self, name: str, module: nn.Module, link: Optional[LinkType] = None
244+
) -> "BlockContainerDict":
213245
"""Prepends a module to a specified name.
214246
215247
Parameters
@@ -218,30 +250,25 @@ def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict":
218250
The name of the branch.
219251
module : nn.Module
220252
The module to prepend.
253+
link : Optional[LinkType]
254+
The link to use for the module. If None, no link is used.
255+
This can either be a Module or a string, options are:
256+
- "residual": Adds a residual connection to the module.
257+
- "shortcut": Adds a shortcut connection to the module.
258+
- "shortcut-concat": Adds a shortcut connection by concatenating
259+
the input and output.
221260
222261
Returns
223262
-------
224263
BlockContainerDict
225264
The current object itself.
226265
"""
227266

228-
self._modules[name].prepend(module)
229-
def append_to(
230-
self, name: str, module: nn.Module, link: Optional[LinkType] = None
231-
) -> "BlockContainerDict":
232-
self._modules[name].append(module, link=link)
233-
234-
return self
235-
236-
def prepend_to(
237-
self, name: str, module: nn.Module, link: Optional[LinkType] = None
238-
) -> "BlockContainerDict":
239267
self._modules[name].prepend(module, link=link)
240268

241-
return self
242-
243-
# Append to all branches, optionally copying
244-
def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
269+
def append_for_each(
270+
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
271+
) -> "BlockContainerDict":
245272
"""Appends a module to each branch.
246273
247274
Parameters
@@ -251,23 +278,29 @@ def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDic
251278
shared : bool, default=False
252279
If True, the same module is shared across all elements.
253280
Otherwise a deep copy of the module is used in each element.
281+
link : Optional[LinkType]
282+
The link to use for the module. If None, no link is used.
283+
This can either be a Module or a string, options are:
284+
- "residual": Adds a residual connection to the module.
285+
- "shortcut": Adds a shortcut connection to the module.
286+
- "shortcut-concat": Adds a shortcut connection by concatenating
287+
the input and output.
254288
255289
Returns
256290
-------
257291
BlockContainerDict
258292
The current object itself.
259293
"""
260294

261-
def append_for_each(
262-
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
263-
) -> "BlockContainerDict":
264295
for branch in self.values():
265296
_module = module if shared else deepcopy(module)
266297
branch.append(_module, link=link)
267298

268299
return self
269300

270-
def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict":
301+
def prepend_for_each(
302+
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
303+
) -> "BlockContainerDict":
271304
"""Prepends a module to each branch.
272305
273306
Parameters
@@ -277,26 +310,25 @@ def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDi
277310
shared : bool, default=False
278311
If True, the same module is shared across all elements.
279312
Otherwise a deep copy of the module is used in each element.
313+
link : Optional[LinkType]
314+
The link to use for the module. If None, no link is used.
315+
This can either be a Module or a string, options are:
316+
- "residual": Adds a residual connection to the module.
317+
- "shortcut": Adds a shortcut connection to the module.
318+
- "shortcut-concat": Adds a shortcut connection by concatenating
319+
the input and output.
280320
281321
Returns
282322
-------
283323
BlockContainerDict
284324
The current object itself.
285325
"""
286-
287-
def prepend_for_each(
288-
self, module: nn.Module, shared=False, link: Optional[LinkType] = None
289-
) -> "BlockContainerDict":
290326
for branch in self.values():
291327
_module = module if shared else deepcopy(module)
292328
branch.prepend(_module, link=link)
293329

294330
return self
295331

296-
# This assumes same branches, we append it's content to each branch
297-
# def append_parallel(self, module: "BlockContainerDict") -> "BlockContainerDict":
298-
# ...
299-
300332
def add_module(self, name: str, module: Optional[nn.Module]) -> None:
301333
if module and not isinstance(module, BlockContainer):
302334
module = BlockContainer(module, name=name[0].upper() + name[1:])

0 commit comments

Comments
 (0)