@@ -244,14 +244,17 @@ def __init__(self,
244244 self .random_state = random_state
245245
246246
247+ def _initialize_networks (self , shape_Xt ):
248+ zeros_enc_ = self .encoder_ .predict (np .zeros ((1 ,) + shape_Xt ));
249+ self .task_ .predict (zeros_enc_ );
250+ self .discriminator_ .predict (zeros_enc_ );
251+
252+
247253 def _build (self , shape_Xs , shape_ys ,
248254 shape_Xt , shape_yt ):
249-
250255 # Call predict to avoid strange behaviour with
251256 # Sequential model whith unspecified input_shape
252- zeros_enc_ = self .encoder_ .predict (np .zeros ((1 ,) + shape_Xt ));
253- self .task_ .predict (zeros_enc_ );
254- self .discriminator_ .predict (zeros_enc_ );
257+ self ._initialize_networks (shape_Xt )
255258
256259 inputs_Xs = Input (shape_Xs )
257260 inputs_ys = Input (shape_ys )
@@ -271,6 +274,7 @@ def _build(self, shape_Xs, shape_ys,
271274 self .model_ = Model (inputs , outputs )
272275
273276 loss = self .get_loss (inputs_ys = inputs_ys ,
277+ inputs_yt = inputs_yt ,
274278 ** outputs )
275279 metrics = self .get_metrics (inputs_ys = inputs_ys ,
276280 inputs_yt = inputs_yt ,
@@ -390,14 +394,17 @@ def create_model(self, inputs_Xs, inputs_Xt):
390394 pass
391395
392396
393- def get_loss (self , inputs_ys , ** ouputs ):
397+ def get_loss (self , inputs_ys , inputs_yt , ** ouputs ):
394398 """
395399 Get loss.
396400
397401 Parameters
398402 ----------
399403 inputs_ys : InputLayer
400404 Input layer for ys entries.
405+
406+ inputs_yt : InputLayer
407+ Input layer for yt entries.
401408
402409 outputs : dict of tf Tensors
403410 Model outputs tensors.
0 commit comments