@@ -21,7 +21,7 @@ def __init__(
2121 scoring_rules : dict [str , ScoringRule ],
2222 body_subnet : str | type = "mlp" , # naming: shared_subnet / body / subnet ?
2323 heads_subnet : dict [str , str | keras .Layer ] = None , # TODO: `type` instead of `keras.Layer` ? Too specific ?
24- activations : dict [str , keras .layers . Activation | Callable | str ] = None ,
24+ activations : dict [str , keras .Layer | Callable | str ] = None ,
2525 ** kwargs ,
2626 ):
2727 super ().__init__ (
@@ -36,17 +36,17 @@ def __init__(
3636
3737 self .body_subnet = find_network (body_subnet , ** kwargs .get ("body_subnet_kwargs" , {}))
3838
39- if heads_subnet :
39+ if heads_subnet is not None :
4040 self .heads = {
4141 key : [find_network (value , ** kwargs .get ("heads_subnet_kwargs" , {}).get (key , {}))]
4242 for key , value in heads_subnet .items ()
4343 }
4444 else :
4545 self .heads = {key : [] for key in self .scoring_rules .keys ()}
4646
47- if activations :
47+ if activations is not None :
4848 self .activations = {
49- key : (value if isinstance (value , keras .layers . Activation ) else keras .layers .Activation (value ))
49+ key : (value if isinstance (value , keras .Layer ) else keras .layers .Activation (value ))
5050 for key , value in activations .items ()
5151 } # make sure that each value is an Activation object
5252 else :
@@ -64,16 +64,16 @@ def __init__(
6464
6565 assert set (self .scoring_rules .keys ()) == set (self .heads .keys ()) == set (self .activations .keys ())
6666
67- def build (self , xz_shape : Shape , conditions_shape : Shape = None ) -> None :
67+ def build (self , xz_shape : Shape , conditions_shape : Shape ) -> None :
6868 # build the shared body network
6969 input_shape = conditions_shape
7070 self .body_subnet .build (input_shape )
7171 body_output_shape = self .body_subnet .compute_output_shape (input_shape )
7272
7373 for key in self .heads .keys ():
74- # head_output_shape (excluding batch_size) convention is (*prediction_shape , *parameter_block_shape)
75- prediction_shape = self .scoring_rules [key ].prediction_shape
76- head_output_shape = prediction_shape + xz_shape [1 :]
74+ # head_output_shape (excluding batch_size) convention is (*target_shape , *parameter_block_shape)
75+ target_shape = self .scoring_rules [key ].target_shape
76+ head_output_shape = target_shape + xz_shape [1 :]
7777
7878 # set correct head shape
7979 self .heads [key ][- 3 ].units = prod (head_output_shape )
@@ -91,13 +91,18 @@ def call(
9191 conditions : Tensor = None ,
9292 training : bool = False ,
9393 ** kwargs ,
94- ) -> Tensor | tuple [ Tensor , Tensor ]:
94+ ) -> dict [ str , Tensor ]:
9595 # TODO: remove unnecessary simularity with InferenceNetwork
9696 return self ._forward (xz , conditions = conditions , training = training , ** kwargs )
9797
9898 def _forward (
99- self , x : Tensor , conditions : Tensor = None , training : bool = False , ** kwargs
100- ) -> Tensor | tuple [Tensor , Tensor ]:
99+ self ,
100+ x : Tensor ,
101+ conditions : Tensor = None ,
102+ training : bool = False ,
103+ ** kwargs ,
104+ # TODO: propagate training flag
105+ ) -> dict [str , Tensor ]:
101106 body_output = self .body_subnet (conditions )
102107
103108 output = dict ()
0 commit comments