@@ -343,9 +343,7 @@ def get_metrics(self, inputs_ys, inputs_yt,
343343 return metrics
344344
345345
346- def _build (self , shape_Xs , shape_ys ,
347- shape_Xt , shape_yt ):
348-
346+ def _initialize_networks (self , shape_Xt ):
349347 # Call predict to avoid strange behaviour with
350348 # Sequential model whith unspecified input_shape
351349 zeros_enc_ = self .encoder_ .predict (np .zeros ((1 ,) + shape_Xt ));
@@ -357,39 +355,6 @@ def _build(self, shape_Xs, shape_ys,
357355 np .expand_dims (zeros_task_ , 1 ))
358356 zeros_mapping_ = np .reshape (zeros_mapping_ , (1 , - 1 ))
359357 self .discriminator_ .predict (zeros_mapping_ );
360-
361- inputs_Xs = Input (shape_Xs )
362- inputs_ys = Input (shape_ys )
363- inputs_Xt = Input (shape_Xt )
364-
365- if shape_yt is None :
366- inputs_yt = None
367- inputs = [inputs_Xs , inputs_ys , inputs_Xt ]
368- else :
369- inputs_yt = Input (shape_yt )
370- inputs = [inputs_Xs , inputs_ys ,
371- inputs_Xt , inputs_yt ]
372-
373- outputs = self .create_model (inputs_Xs = inputs_Xs ,
374- inputs_Xt = inputs_Xt )
375-
376- self .model_ = Model (inputs , outputs )
377-
378- loss = self .get_loss (inputs_ys = inputs_ys ,
379- ** outputs )
380- metrics = self .get_metrics (inputs_ys = inputs_ys ,
381- inputs_yt = inputs_yt ,
382- ** outputs )
383-
384- self .model_ .add_loss (loss )
385- for k in metrics :
386- self .model_ .add_metric (tf .reduce_mean (metrics [k ]),
387- name = k , aggregation = "mean" )
388-
389- tf .compat .v1 .logging .set_verbosity (tf .compat .v1 .logging .ERROR )
390- self .model_ .compile (optimizer = self .optimizer )
391- self .history_ = {}
392- return self
393358
394359
395360 def predict_disc (self , X ):
0 commit comments