99from Orange .regression import RandomForestRegressionLearner
1010from Orange .tests import test_filename
1111
12+
1213class RandomForestTest (unittest .TestCase ):
1314 @classmethod
1415 def setUpClass (cls ):
1516 cls .iris = Table ('iris' )
16- cls .house = Table ('housing' )
17+ cls .housing = Table ('housing' )
1718
1819 def test_RandomForest (self ):
1920 forest = RandomForestLearner ()
@@ -43,28 +44,28 @@ def test_predict_numpy(self):
4344
4445 def test_RandomForestRegression (self ):
4546 forest = RandomForestRegressionLearner ()
46- results = CrossValidation (self .house , [forest ], k = 10 )
47+ results = CrossValidation (self .housing , [forest ], k = 10 )
4748 _ = RMSE (results )
4849
4950 def test_predict_single_instance_reg (self ):
5051 forest = RandomForestRegressionLearner ()
51- model = forest (self .house )
52- for ins in self .house :
52+ model = forest (self .housing )
53+ for ins in self .housing :
5354 pred = model (ins )
5455 self .assertGreater (pred , 0 )
5556
5657 def test_predict_table_reg (self ):
5758 forest = RandomForestRegressionLearner ()
58- model = forest (self .house )
59- pred = model (self .house )
60- self .assertEqual (len (self .house ), len (pred ))
59+ model = forest (self .housing )
60+ pred = model (self .housing )
61+ self .assertEqual (len (self .housing ), len (pred ))
6162 self .assertGreater (all (pred ), 0 )
6263
6364 def test_predict_numpy_reg (self ):
6465 forest = RandomForestRegressionLearner ()
65- model = forest (self .house )
66- pred = model (self .house .X )
67- self .assertEqual (len (self .house ), len (pred ))
66+ model = forest (self .housing )
67+ pred = model (self .housing .X )
68+ self .assertEqual (len (self .housing ), len (pred ))
6869 self .assertGreater (all (pred ), 0 )
6970
7071 def test_classification_scorer (self ):
@@ -78,9 +79,9 @@ def test_classification_scorer(self):
7879
7980 def test_regression_scorer (self ):
8081 learner = RandomForestRegressionLearner ()
81- scores = learner .score_data (self .house )
82+ scores = learner .score_data (self .housing )
8283 self .assertEqual (['LSTAT' , 'RM' ],
83- sorted ([self .house .domain .attributes [i ].name
84+ sorted ([self .housing .domain .attributes [i ].name
8485 for i in np .argsort (scores [0 ])[- 2 :]]))
8586
8687 def test_scorer_feature (self ):
@@ -92,3 +93,19 @@ def test_scorer_feature(self):
9293 np .random .seed (42 )
9394 score = learner .score_data (data , attr )
9495 np .testing .assert_array_almost_equal (score , scores [:, i ])
96+
97+ def test_get_classification_trees (self ):
98+ n = 5
99+ forest = RandomForestLearner (n_estimators = n )
100+ model = forest (self .iris )
101+ self .assertEqual (len (model .trees ), n )
102+ tree = model .trees [0 ]
103+ self .assertEqual (tree (self .iris [0 ]), 0 )
104+
105+ def test_get_regression_trees (self ):
106+ n = 5
107+ forest = RandomForestRegressionLearner (n_estimators = n )
108+ model = forest (self .housing )
109+ self .assertEqual (len (model .trees ), n )
110+ tree = model .trees [0 ]
111+ tree (self .housing [0 ])
0 commit comments