@@ -465,30 +465,33 @@ def test_output_model_name(self):
465465 self .widget .apply_button .button .click ()
466466 self .assertEqual (self .get_output (self .model_name ).name , new_name )
467467
468+ def _get_param_value (self , learner , param ):
469+ if isinstance (learner , Fitter ):
470+ # Both is just a was to indicate to the tests, fitters don't
471+ # actually support this
472+ if param .problem_type == "both" :
473+ problem_type = learner .CLASSIFICATION
474+ else :
475+ problem_type = param .problem_type
476+ return learner .get_params (problem_type ).get (param .name )
477+ else :
478+ return learner .params .get (param .name )
479+
468480 def test_parameters_default (self ):
469481 """Check if learner's parameters are set to default (widget's) values
470482 """
471483 for dataset in self .valid_datasets :
472484 self .send_signal ("Data" , dataset )
473485 self .widget .apply_button .button .click ()
474- if hasattr (self .widget .learner , "params" ):
475- learner_params = self .widget .learner .params
476- for parameter in self .parameters :
477- # Skip if the param isn't used for the given data type
478- if self ._should_check_parameter (parameter , dataset ):
479- self .assertEqual (learner_params .get (parameter .name ),
480- parameter .get_value ())
486+ for parameter in self .parameters :
487+ # Skip if the param isn't used for the given data type
488+ if self ._should_check_parameter (parameter , dataset ):
489+ self .assertEqual (
490+ self ._get_param_value (self .widget .learner , parameter ),
491+ parameter .get_value ())
481492
482493 def test_parameters (self ):
483494 """Check learner and model for various values of all parameters"""
484-
485- def get_value (learner , name ):
486- # Handle SKL and skl-like learners, and non-SKL learners
487- if hasattr (learner , "params" ):
488- return learner .params .get (name )
489- else :
490- return getattr (learner , name )
491-
492495 # Test params on every valid dataset, since some attributes may apply
493496 # to only certain problem types
494497 for dataset in self .valid_datasets :
@@ -504,24 +507,22 @@ def get_value(learner, name):
504507 for value in parameter .values :
505508 parameter .set_value (value )
506509 self .widget .apply_button .button .click ()
507- param = get_value (self .widget .learner , parameter . name )
510+ param = self . _get_param_value (self .widget .learner , parameter )
508511 self .assertEqual (
509512 param , parameter .get_value (),
510513 "Mismatching setting for parameter '%s'" % parameter )
511514 self .assertEqual (
512515 param , value ,
513516 "Mismatching setting for parameter '%s'" % parameter )
514- param = get_value (self .get_output ("Learner" ),
515- parameter .name )
517+ param = self ._get_param_value (self .get_output ("Learner" ), parameter )
516518 self .assertEqual (
517519 param , value ,
518520 "Mismatching setting for parameter '%s'" % parameter )
519521
520522 if issubclass (self .widget .LEARNER , SklModel ):
521523 model = self .get_output (self .model_name )
522524 if model is not None :
523- self .assertEqual (get_value (model , parameter .name ),
524- value )
525+ self .assertEqual (self ._get_param_value (model , parameter ), value )
525526 self .assertFalse (self .widget .Error .active )
526527 else :
527528 self .assertTrue (self .widget .Error .active )
0 commit comments