@@ -206,8 +206,8 @@ def make_experiment(hyper_model_cls,
206206 - nf
207207 optimize_direction : str, optional
208208 Hypernets search reward metric direction, default is detected from reward_metric.
209- discriminator : instance of hypernets.discriminator.BaseDiscriminator, optional
210- Discriminator is used to determine whether to continue training
209+ discriminator : instance of hypernets.discriminator.BaseDiscriminator or bool , optional
210+ Discriminator is used to determine whether to continue training, set False to disable it.
211211 hyper_model_options: dict, optional
212212 Options to initlize HyperModel except *reward_metric*, *task*, *callbacks*, *discriminator*.
213213 evaluation_metrics: str, list, or None (default='auto'),
@@ -365,10 +365,14 @@ def append_early_stopping_callbacks(cbs):
365365 report_render = to_report_render_object (report_render , report_render_options )
366366 callbacks .append (MLReportCallback (report_render ))
367367
368- if discriminator is None and cfg .experiment_discriminator is not None and len (cfg .experiment_discriminator ) > 0 :
368+ if ((discriminator is None or discriminator is True )
369+ and cfg .experiment_discriminator is not None
370+ and len (cfg .experiment_discriminator ) > 0 ):
369371 discriminator = make_discriminator (cfg .experiment_discriminator ,
370372 optimize_direction = optimize_direction ,
371373 ** (cfg .experiment_discriminator_options or {}))
374+ elif discriminator is False :
375+ discriminator = None
372376
373377 if id is None :
374378 hasher = tb .data_hasher ()
0 commit comments