Skip to content

Commit ed7caee

Browse files
authored
Adding registry (#1090)
* Increase test-coverage * Adding ParallelBlock * Adding registry * Fix merge conflicts
1 parent 2c6eafc commit ed7caee

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

merlin/models/torch/block.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222

2323
from merlin.models.torch.batch import Batch
2424
from merlin.models.torch.container import BlockContainer, BlockContainerDict
25+
from merlin.models.torch.registry import registry
26+
from merlin.models.utils.registry import RegistryMixin
2527

2628

27-
class Block(BlockContainer):
29+
class Block(BlockContainer, RegistryMixin):
2830
"""A base-class that calls it's modules sequentially.
2931
3032
Parameters
@@ -35,6 +37,8 @@ class Block(BlockContainer):
3537
The name of the block. If None, no name is assigned.
3638
"""
3739

40+
registry = registry
41+
3842
def __init__(self, *module: nn.Module, name: Optional[str] = None):
3943
super().__init__(*module, name=name)
4044

merlin/models/torch/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from merlin.models.utils.registry import Registry
2+
3+
registry: Registry = Registry.class_registry("modules")

tests/unit/torch/test_block.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,20 @@ def test_repeat(self):
8282
with pytest.raises(ValueError, match="n must be greater than 0"):
8383
block.repeat(0)
8484

85+
def test_from_registry(self):
86+
@Block.registry.register("my_block")
87+
class MyBlock(Block):
88+
def forward(self, inputs):
89+
_inputs = inputs + 1
90+
91+
return super().forward(_inputs)
92+
93+
block = Block.parse("my_block")
94+
assert block.__class__ == MyBlock
95+
96+
inputs = torch.randn(1, 3)
97+
assert torch.equal(block(inputs), inputs + 1)
98+
8599

86100
class TestParallelBlock:
87101
def test_init(self):

0 commit comments

Comments
 (0)