@@ -157,6 +157,8 @@ def __call__(self, data, outcome, treatment, y_task, x_task,
157157 is_discrete_treatment = x_task if isinstance (x_task , bool ) else x_task != const .TASK_REGRESSION
158158 is_discrete_outcome = y_task if isinstance (y_task , bool ) else y_task != const .TASK_REGRESSION
159159
160+ assert is_discrete_treatment , 'SLearner support discrete treatment only.'
161+
160162 return PermutedSLearner (
161163 model = self ._model (data , task = y_task , estimator = self .model , random_state = random_state ),
162164 is_discrete_outcome = is_discrete_outcome ,
@@ -179,6 +181,8 @@ def __call__(self, data, outcome, treatment, y_task, x_task,
179181 is_discrete_treatment = x_task if isinstance (x_task , bool ) else x_task != const .TASK_REGRESSION
180182 is_discrete_outcome = y_task if isinstance (y_task , bool ) else y_task != const .TASK_REGRESSION
181183
184+ assert is_discrete_treatment , 'TLearner support discrete treatment only.'
185+
182186 return PermutedTLearner (
183187 model = self ._model (data , task = y_task , estimator = self .model , random_state = random_state ),
184188 is_discrete_outcome = is_discrete_outcome ,
@@ -202,6 +206,8 @@ def __call__(self, data, outcome, treatment, y_task, x_task,
202206 is_discrete_treatment = x_task if isinstance (x_task , bool ) else x_task != const .TASK_REGRESSION
203207 is_discrete_outcome = y_task if isinstance (y_task , bool ) else y_task != const .TASK_REGRESSION
204208
209+ assert is_discrete_treatment , 'XLearner support discrete treatment only.'
210+
205211 if is_discrete_outcome :
206212 final_proba_model = self ._model (
207213 data , task = const .TASK_REGRESSION , estimator = self .final_proba_model , random_state = random_state )
@@ -229,12 +235,41 @@ def __call__(self, data, outcome, treatment, y_task, x_task,
229235 adjustment = None , covariate = None , instrument = None , random_state = None ):
230236 from ylearn .estimator_model ._permuted import PermutedCausalTree
231237
238+ is_discrete_treatment = x_task if isinstance (x_task , bool ) else x_task != const .TASK_REGRESSION
239+ is_discrete_outcome = y_task if isinstance (y_task , bool ) else y_task != const .TASK_REGRESSION
240+
241+ assert is_discrete_treatment , 'CausalTree support discrete treatment only.'
242+
232243 options = self .options .copy ()
233244 if random_state is not None :
234245 options ['random_state' ] = random_state
246+ # options['is_discrete_outcome'] = is_discrete_outcome
247+ # options['is_discrete_treatment'] = is_discrete_treatment
248+
235249 return PermutedCausalTree (** options )
236250
237251
252+ @register ()
253+ class GrfFactory (BaseEstimatorFactory ):
254+ def __init__ (self , ** kwargs ):
255+ self .options = kwargs .copy ()
256+
257+ def __call__ (self , data , outcome , treatment , y_task , x_task ,
258+ adjustment = None , covariate = None , instrument = None , random_state = None ):
259+ from ylearn .estimator_model ._generalized_forest import GRForest
260+
261+ is_discrete_treatment = x_task if isinstance (x_task , bool ) else x_task != const .TASK_REGRESSION
262+ is_discrete_outcome = y_task if isinstance (y_task , bool ) else y_task != const .TASK_REGRESSION
263+
264+ options = self .options .copy ()
265+ if random_state is not None :
266+ options ['random_state' ] = random_state
267+ options ['is_discrete_outcome' ] = is_discrete_outcome
268+ options ['is_discrete_treatment' ] = is_discrete_treatment
269+
270+ return GRForest (** options )
271+
272+
238273@register ()
239274@register (name = 'bound' )
240275class ApproxBoundFactory (BaseEstimatorFactory ):
0 commit comments