 
 import joblib
 import lightgbm as lgb
+from numba.core.utils import erase_traceback
 import numpy as np
 from numpy.lib.function_base import iterable
 import pandas as pd
-from sklearn.base import TransformerMixin
+from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.compose import ColumnTransformer
 from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier, RandomForestRegressor
 from sklearn.linear_model import Lasso, LassoCV, LogisticRegression, LogisticRegressionCV
@@ -172,7 +173,7 @@ def _first_stage_clf(X, y, *, make_regressor=False, automl=True, min_count=None,
     else:
         model = LogisticRegressionCV(
             cv=min(5, min_count), max_iter=1000, Cs=cs, random_state=random_state).fit(X, y)
-        est = LogisticRegression(C=model.C_[0], random_state=random_state)
+        est = LogisticRegression(C=model.C_[0], max_iter=1000, random_state=random_state)
     if make_regressor:
         return _RegressionWrapper(est)
     else:
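
For context, a standalone sketch (toy data, not from this commit) of the CV-then-refit pattern this hunk repairs: sklearn's `LogisticRegression` defaults to `max_iter=100`, so before the fix the refit estimator could fail to converge on data the `max_iter=1000` CV search had handled fine.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + 0.1 * rng.normal(size=200) > 0).astype(int)

# search for a good C with plenty of iterations...
cv_model = LogisticRegressionCV(cv=5, max_iter=1000, Cs=10, random_state=0).fit(X, y)
# ...then refit a plain estimator with the chosen C; without max_iter=1000 this
# falls back to sklearn's default of 100 iterations and may warn about convergence
est = LogisticRegression(C=cv_model.C_[0], max_iter=1000, random_state=0).fit(X, y)
```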
@@ -192,8 +193,6 @@ def _final_stage(*, random_state=None, verbose=0):
 
 # simplification of sklearn's ColumnTransformer that encodes categoricals and passes through selected other columns
 # but also supports get_feature_names with expected signature
-
-
 class _ColumnTransformer(TransformerMixin):
     def __init__(self, categorical, passthrough):
         self.categorical = categorical
@@ -208,22 +207,16 @@ def fit(self, X):
                                                  handle_unknown='ignore').fit(cat_cols)
         else:
             self.has_cats = False
-        cont_cols = _safe_indexing(X, self.passthrough, axis=1)
-        if cont_cols.shape[1] > 0:
-            self.has_conts = True
-            self.scaler = StandardScaler().fit(cont_cols)
-        else:
-            self.has_conts = False
         self.d_x = X.shape[1]
         return self
 
     def transform(self, X):
         rest = _safe_indexing(X, self.passthrough, axis=1)
-        if self.has_conts:
-            rest = self.scaler.transform(rest)
         if self.has_cats:
             cats = self.one_hot_encoder.transform(_safe_indexing(X, self.categorical, axis=1))
-            return np.hstack((cats, rest))
+            # NOTE: we rely on the passthrough columns coming first in the concatenated X;W
+            # when we pipeline scaling with our first stage models later, so the order here is important
+            return np.hstack((rest, cats))
         else:
             return rest
 
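
A toy illustration (hypothetical values) of the column order the NOTE above establishes:

```python
import numpy as np

rest = np.array([[1.5, -0.2]])      # passthrough (continuous) columns come first...
cats = np.array([[0., 1., 0.]])     # ...then the one-hot encoded categoricals
print(np.hstack((rest, cats)))      # [[ 1.5 -0.2  0.   1.   0. ]]
# downstream, the leading len(passthrough) columns of [X;W] can be peeled off
# for rescaling inside a Pipeline without touching the encoded categoricals
```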
@@ -234,11 +227,32 @@ def get_feature_names(self, names=None):
         if self.has_cats:
             cats = self.one_hot_encoder.get_feature_names(
                 _safe_indexing(names, self.categorical, axis=0))
-            return np.concatenate((cats, rest))
+            return np.concatenate((rest, cats))
         else:
             return rest
 
 
+# Wrapper to make sure that we get a deep copy of the contents instead of clone returning an untrained copy
+class _Wrapper:
+    def __init__(self, item):
+        self.item = item
+
+
+class _FrozenTransformer(TransformerMixin, BaseEstimator):
+    def __init__(self, wrapper):
+        self.wrapper = wrapper
+
+    def fit(self, X, y):
+        return self
+
+    def transform(self, X):
+        return self.wrapper.item.transform(X)
+
+
+def _freeze(transformer):
+    return _FrozenTransformer(_Wrapper(transformer))
+
+
 # Convert python objects to (possibly nested) types that can easily be represented as literals
 def _sanitize(obj):
     if obj is None or isinstance(obj, (bool, int, str, float)):
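
A standalone sketch (illustrative, not part of the commit) of why the freezing machinery is needed: sklearn's `clone`, which DML applies to its first-stage models during cross-fitting, returns an unfitted copy, while `_freeze` preserves the fitted state.

```python
import numpy as np
from sklearn.base import clone
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(np.array([[0.0], [10.0]]))  # learned mean=5, scale=5
frozen = _freeze(scaler)  # _FrozenTransformer(_Wrapper(scaler)) from the hunk above

# clone(scaler) would return an unfitted copy that raises NotFittedError on transform;
# clone(frozen) deep-copies the plain _Wrapper (it is not an estimator), so the fitted
# scaler inside survives cloning and keeps producing identical output
assert np.allclose(clone(frozen).transform([[5.0]]),
                   scaler.transform([[5.0]]))
```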
@@ -310,6 +324,13 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
     else:
         cats = 'auto'  # just leave the setting at the default otherwise
 
+    # the transformation logic here is somewhat tricky; we always need to encode the categorical columns,
+    # whether they end up in X or in W. However, for the continuous columns, we want to scale them all
+    # when running the first stage models, but don't want to scale the X columns when running the final model,
+    # since then our coefficients will have odd units and our trees will also have decisions using those units.
+    #
+    # we achieve this by pipelining the X scaling with the Y and T models (with fixed scaling, not refitting)
+
     hinds = heterogeneity_inds[feat_ind]
     WX_transformer = ColumnTransformer([('encode', OneHotEncoder(drop='first', sparse=False),
                                          [ind for ind in categorical_inds
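
A minimal sketch of the strategy described in the comment above, under assumed toy data and a hypothetical two-column X: the nuisance model sees standardized X columns via a pipelined `ColumnTransformer`, while the final stage is fit against raw X.

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LassoCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(loc=100.0, scale=20.0, size=(500, 2))  # heterogeneity features, raw units
W = rng.normal(size=(500, 3))                         # controls, scaled by W_transformer
y = 0.5 * X[:, 0] + W @ [1.0, -1.0, 2.0] + rng.normal(size=500)

# DML hands [X;W] to the first-stage models, so scale only the leading X columns
scale_X_only = ColumnTransformer([('scale', StandardScaler(), [0, 1])],
                                 remainder='passthrough')
model_y = Pipeline([('scale', scale_X_only), ('model', LassoCV())])
model_y.fit(np.hstack([X, W]), y)  # nuisance model trains on standardized X
# the final CATE model, by contrast, is fit against raw X, keeping its units interpretable
```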
@@ -322,11 +343,14 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
                                        ('drop', 'drop', hinds),
                                        ('drop_feat', 'drop', feat_ind)],
                                       remainder=StandardScaler())
+
+    X_cont_inds = [ind for ind in hinds
+                   if ind != feat_ind and ind not in categorical_inds]
+
     # Use _ColumnTransformer instead of ColumnTransformer so we can get feature names
     X_transformer = _ColumnTransformer([ind for ind in categorical_inds
                                         if ind != feat_ind and ind in hinds],
-                                       [ind for ind in hinds
-                                        if ind != feat_ind and ind not in categorical_inds])
+                                       X_cont_inds)
 
     # Controls are all other columns of X
     WX = WX_transformer.fit_transform(X)
@@ -340,6 +364,20 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
 
     W = W_transformer.fit_transform(X)
     X_xf = X_transformer.fit_transform(X)
+
+    # HACK: this is slightly ugly because we rely on the fact that DML passes [X;W] to the first stage models
+    # and so we can just peel the first columns off of that combined array for rescaling in the pipeline
+    # TODO: consider adding an API to DML that allows for better understanding of how the nuisance inputs are
+    #       built, such as model_y_feature_names, model_t_feature_names, model_y_transformer, etc., so that this
+    #       becomes a valid approach to handling this
+    X_scaler = ColumnTransformer([('scale', StandardScaler(),
+                                   list(range(len(X_cont_inds))))],
+                                 remainder='passthrough').fit(np.hstack([X_xf, W])).named_transformers_['scale']
+
+    X_scaler_fixed = ColumnTransformer([('scale', _freeze(X_scaler),
+                                         list(range(len(X_cont_inds))))],
+                                       remainder='passthrough')
+
     if W.shape[1] == 0:
         # array checking routines don't accept 0-width arrays
         W = None
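
A hedged sketch of the fit-then-freeze step above, assuming two continuous X columns of toy data and using the `_freeze` helper added earlier:

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_xf = rng.normal(loc=5.0, scale=2.0, size=(100, 2))  # stands in for X_transformer's output
W = rng.normal(size=(100, 3))                         # stands in for W_transformer's output

fitted = ColumnTransformer([('scale', StandardScaler(), [0, 1])],
                           remainder='passthrough').fit(np.hstack([X_xf, W]))
X_scaler = fitted.named_transformers_['scale']        # the already-fit StandardScaler

# freezing X_scaler before re-wrapping it means that every Pipeline.fit call during
# cross-fitting reuses these exact means/scales instead of re-estimating them per fold
X_scaler_fixed = ColumnTransformer([('scale', _freeze(X_scaler), [0, 1])],
                                   remainder='passthrough')
```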
@@ -358,14 +396,20 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
                                random_state=random_state,
                                verbose=verbose))
 
+    pipelined_model_t = Pipeline([('scale', X_scaler_fixed),
+                                  ('model', model_t)])
+
+    pipelined_model_y = Pipeline([('scale', X_scaler_fixed),
+                                  ('model', model_y)])
+
     if X_xf is None and h_model == 'forest':
         warnings.warn(f"Using a linear model instead of a forest model for feature '{name}' "
                       "because forests don't support models with no heterogeneity indices")
         h_model = 'linear'
 
     if h_model == 'linear':
-        est = LinearDML(model_y=model_y,
-                        model_t=model_t,
+        est = LinearDML(model_y=pipelined_model_y,
+                        model_t=pipelined_model_t,
                         discrete_treatment=discrete_treatment,
                         fit_cate_intercept=True,
                         linear_first_stages=False,
@@ -374,8 +418,8 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
                         cv=cv,
                         mc_iters=mc_iters)
     elif h_model == 'forest':
-        est = CausalForestDML(model_y=model_y,
-                              model_t=model_t,
+        est = CausalForestDML(model_y=pipelined_model_y,
+                              model_t=pipelined_model_t,
                               discrete_treatment=discrete_treatment,
                               n_estimators=4000,
                               min_var_leaf_on_val=True,
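
Putting it together, a hedged end-to-end illustration of the net effect (toy data; the variable choices are assumptions, not the commit's code): frozen-scaler pipelines are passed as DML nuisance models, so only the first stage sees scaled features.

```python
import numpy as np
from econml.dml import LinearDML
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LassoCV, LogisticRegressionCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(loc=50.0, scale=10.0, size=(500, 2))
T = rng.binomial(1, 0.5, size=500)
y = (1.0 + 0.1 * X[:, 0]) * T + rng.normal(size=500)

# fixed scaler over the continuous X columns, frozen so cross-fitting can't refit it
X_scaler_fixed = ColumnTransformer([('scale', _freeze(StandardScaler().fit(X)), [0, 1])],
                                   remainder='passthrough')

est = LinearDML(model_y=Pipeline([('scale', X_scaler_fixed), ('model', LassoCV())]),
                model_t=Pipeline([('scale', X_scaler_fixed), ('model', LogisticRegressionCV())]),
                discrete_treatment=True, random_state=0)
est.fit(y, T, X=X)
print(est.coef_)  # CATE coefficients remain in X's raw units: only the nuisance
                  # models, never the final stage, see the scaled features
```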