@@ -772,6 +772,54 @@ def forward(self, input: Tensor) -> Tensor:
772772 return input @ self .linear .weight .T + self .linear .bias
773773
774774
775+ class WithModuleWithStringArg (ShapedModule ):
776+ """Model containing a module that has a string argument."""
777+
778+ INPUT_SHAPES = (2 ,)
779+ OUTPUT_SHAPES = (3 ,)
780+
781+ class WithStringArg (nn .Module ):
782+ def __init__ (self ):
783+ super ().__init__ ()
784+ self .matrix = nn .Parameter (torch .randn (2 , 3 ))
785+
786+ def forward (self , s : str , input : Tensor ) -> Tensor :
787+ if s == "two" :
788+ return input @ self .matrix * 2.0
789+ else :
790+ return input @ self .matrix
791+
792+ def __init__ (self ):
793+ super ().__init__ ()
794+ self .with_string_arg = self .WithStringArg ()
795+
796+ def forward (self , input : Tensor ) -> Tensor :
797+ return self .with_string_arg ("two" , input )
798+
799+
800+ class WithModuleWithStringOutput (ShapedModule ):
801+ """Model containing a module that has a string output."""
802+
803+ INPUT_SHAPES = (2 ,)
804+ OUTPUT_SHAPES = (3 ,)
805+
806+ class WithStringOutput (nn .Module ):
807+ def __init__ (self ):
808+ super ().__init__ ()
809+ self .matrix = nn .Parameter (torch .randn (2 , 3 ))
810+
811+ def forward (self , input : Tensor ) -> tuple [str , Tensor ]:
812+ return "test" , input @ self .matrix
813+
814+ def __init__ (self ):
815+ super ().__init__ ()
816+ self .with_string_output = self .WithStringOutput ()
817+
818+ def forward (self , input : Tensor ) -> Tensor :
819+ _ , output = self .with_string_output (input )
820+ return output
821+
822+
775823class FreeParam (ShapedModule ):
776824 """
777825 Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the
0 commit comments