File tree Expand file tree Collapse file tree 3 files changed +22
-1
lines changed
Expand file tree Collapse file tree 3 files changed +22
-1
lines changed Original file line number Diff line number Diff line change 2222
2323from merlin .models .torch .batch import Batch
2424from merlin .models .torch .container import BlockContainer , BlockContainerDict
25+ from merlin .models .torch .registry import registry
26+ from merlin .models .utils .registry import RegistryMixin
2527
2628
27- class Block (BlockContainer ):
29+ class Block (BlockContainer , RegistryMixin ):
2830 """A base-class that calls it's modules sequentially.
2931
3032 Parameters
@@ -35,6 +37,8 @@ class Block(BlockContainer):
3537 The name of the block. If None, no name is assigned.
3638 """
3739
40+ registry = registry
41+
3842 def __init__ (self , * module : nn .Module , name : Optional [str ] = None ):
3943 super ().__init__ (* module , name = name )
4044
Original file line number Diff line number Diff line change 1+ from merlin .models .utils .registry import Registry
2+
3+ registry : Registry = Registry .class_registry ("modules" )
Original file line number Diff line number Diff line change @@ -82,6 +82,20 @@ def test_repeat(self):
8282 with pytest .raises (ValueError , match = "n must be greater than 0" ):
8383 block .repeat (0 )
8484
85+ def test_from_registry (self ):
86+ @Block .registry .register ("my_block" )
87+ class MyBlock (Block ):
88+ def forward (self , inputs ):
89+ _inputs = inputs + 1
90+
91+ return super ().forward (_inputs )
92+
93+ block = Block .parse ("my_block" )
94+ assert block .__class__ == MyBlock
95+
96+ inputs = torch .randn (1 , 3 )
97+ assert torch .equal (block (inputs ), inputs + 1 )
98+
8599
86100class TestParallelBlock :
87101 def test_init (self ):
You can’t perform that action at this time.
0 commit comments