@@ -108,15 +108,13 @@ def test_transfer_tree():
108108 transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "ser" )
109109 transferred_rf ._ser_rf (Xt , yt ,original_ser = False ,no_red_on_cl = True ,cl_no_red = [0 ])
110110
111- # WARNING! Error Raised with this test
112111 if method == 'ser_no_ext' :
113- pass
114112 #decision tree
115- # transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="ser")
116- # transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_ext_on_cl=True,cl_no_red =[0],ext_cond=True)
113+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "ser" )
114+ transferred_dt ._ser (Xt , yt ,node = 0 ,original_ser = False ,no_ext_on_cl = True ,cl_no_ext = [0 ],ext_cond = True )
117115 #random forest
118- # transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="ser")
119- # transferred_rf._ser_rf(Xt, yt,original_ser=False,no_ext_on_cl=True,cl_no_ext=[0],ext_cond=True)
116+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "ser" )
117+ transferred_rf ._ser_rf (Xt , yt ,original_ser = False ,no_ext_on_cl = True ,cl_no_ext = [0 ],ext_cond = True )
120118 if method == 'ser_nr_lambda' :
121119 #decision tree
122120 transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "ser" )
@@ -159,7 +157,7 @@ def test_transfer_tree():
159157 root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
160158 #random forest
161159 transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
162- transferred_rf ._strut (Xt , yt ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
160+ transferred_rf ._strut_rf (Xt , yt ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
163161 leaf_loss_quantify = False ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
164162 root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
165163 if method == 'strut_lambda_np' :
@@ -170,7 +168,7 @@ def test_transfer_tree():
170168 root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
171169 #random forest
172170 transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
173- transferred_rf ._strut (Xt , yt ,adapt_prop = True ,no_prune_on_cl = True ,cl_no_prune = [0 ],
171+ transferred_rf ._strut_rf (Xt , yt ,adapt_prop = True ,no_prune_on_cl = True ,cl_no_prune = [0 ],
174172 leaf_loss_quantify = True ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
175173 root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
176174 if method == 'strut_lambda_np2' :
@@ -179,12 +177,11 @@ def test_transfer_tree():
179177 transferred_dt ._strut (Xt , yt ,node = 0 ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
180178 leaf_loss_quantify = False ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
181179 root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
182- # Warning! Error Raised because `strut` not in TransferForest
183180 #random forest
184- # transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
185- # transferred_rf._strut (Xt, yt,adapt_prop=True,no_prune_on_cl=True,cl_no_prune=[0],
186- # leaf_loss_quantify=True,leaf_loss_threshold=0.5,no_prune_with_translation=True,
187- # root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
181+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
182+ transferred_rf ._strut_rf (Xt , yt ,adapt_prop = True ,no_prune_on_cl = True ,cl_no_prune = [0 ],
183+ leaf_loss_quantify = True ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = True ,
184+ root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
188185
189186 score = transferred_dt .estimator .score (Xt_test , yt_test )
190187 #score = clf_transfer.score(Xt_test, yt_test)
0 commit comments