@@ -69,7 +69,7 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
6969 # is resolved, a flat version of the heads dictionary is kept.
7070 # This allows to save head weights properly, see for reference
7171 # https://github.com/keras-team/keras/blob/v3.3.3/keras/src/saving/saving_lib.py#L481.
72- # A nested heads dict is still prefered over this flat dict,
72+ # A nested heads dict is still preferred over this flat dict,
7373 # because it avoids string operation based filtering in `self._forward()`.
7474 flat_key = f"{ score_key } ___{ head_key } "
7575 self .heads_flat [flat_key ] = head
@@ -96,17 +96,12 @@ def call(
9696 if xz is None and not self .built :
9797 raise ValueError ("Cannot build inference network without inference variables." )
9898 if conditions is None : # unconditional estimation uses a fixed input vector
99- conditions = keras .ops .convert_to_tensor ([[1.0 ]], dtype = "float32" )
100- return self ._forward (conditions = conditions , training = training , ** kwargs )
99+ conditions = keras .ops .convert_to_tensor ([[1.0 ]], dtype = keras .ops .dtype (xz ))
101100
102- def _forward (
103- self ,
104- conditions : Tensor = None ,
105- training : bool = False ,
106- ** kwargs ,
107- ) -> dict [str , Tensor ]:
101+ # pass conditions to the shared subnet
108102 output = self .subnet (conditions , training = training )
109103
104+ # pass along to calculate individual head outputs
110105 output = {
111106 score_key : {head_key : head (output , training = training ) for head_key , head in self .heads [score_key ].items ()}
112107 for score_key in self .heads .keys ()
0 commit comments