1212from sklearn .model_selection import train_test_split
1313
1414
15+ sys .path .insert (0 , os .getcwd ())
16+ import causaltune # noqa: E402
17+
1518from causaltune import CausalTune
1619from causaltune .data_utils import CausalityDataset
1720from causaltune .models .passthrough import passthrough_model
@@ -112,6 +115,7 @@ def run_experiment(
112115 estimators : List [str ],
113116 dataset_path : str ,
114117 use_ray : bool ,
118+ propensity_automl_estimators : Optional [List [str ]] = None ,
115119):
116120 # Process datasets
117121 data_sets = {}
@@ -125,6 +129,7 @@ def run_experiment(
125129 name = " " .join (parts [1 :])
126130 file_path = f"{ dataset_path } /{ size } /{ name } .pkl"
127131 data_sets [f"{ size } { name } " ] = load_dataset (file_path )
132+ run_kind = dataset .split ("_" )[1 ]
128133
129134 out_dir = f"../EXPERIMENT_RESULTS_{ args .identifier } "
130135 os .makedirs (out_dir , exist_ok = True )
@@ -136,24 +141,22 @@ def run_experiment(
136141 already_running = False
137142 if use_ray :
138143 try :
139- runner = ray .get_actor ("TaskRunner" )
144+ runner = ray .get_actor (f "TaskRunner { run_kind } " )
140145 print ("\n " * 4 )
141146 print (
142147 "!!! Found an existing detached TaskRunner. Will assume the tasks have already been submitted."
143148 )
144149 print (
145- "!!! If you want to re-run the experiments from scratch, "
146- 'run ray.kill(ray.get_actor("TaskRunner", namespace="{}")) or recreate the cluster.' .format (
147- RAY_NAMESPACE
148- )
150+ f"!!! If you want to re-run the experiments from scratch, "
151+ 'run ray.kill(ray.get_actor("TaskRunner {run_kind}", namespace="{RAY_NAMESPACE}")) or recreate the cluster.'
149152 )
150153 print ("\n " * 4 )
151154 already_running = True
152155 except ValueError :
153156 print ("Ray: no detached TaskRunner found, creating..." )
154157 # This thing will be alive even if the host program exits
155- # Must be killed explicitly: ray.kill(ray.get_actor("TaskRunner"))
156- runner = TaskRunner .options (name = "TaskRunner" , lifetime = "detached" ).remote ()
158+ # Must be killed explicitly: ray.kill(ray.get_actor(f "TaskRunner {run_kind} "))
159+ runner = TaskRunner .options (name = f "TaskRunner { run_kind } " , lifetime = "detached" ).remote ()
157160
158161 out = []
159162 if not already_running :
@@ -190,6 +193,7 @@ def run_experiment(
190193 args .components_time_budget ,
191194 out_fn ,
192195 estimators ,
196+ propensity_automl_estimators ,
193197 )
194198 )
195199 else :
@@ -202,6 +206,7 @@ def run_experiment(
202206 args .components_time_budget ,
203207 out_fn ,
204208 estimators ,
209+ propensity_automl_estimators ,
205210 )
206211 out .append (results )
207212
@@ -238,6 +243,7 @@ def run_batch(
238243 estimators : List [str ],
239244 dataset_path : str ,
240245 use_ray : bool = False ,
246+ propensity_automl_estimators : Optional [List [str ]] = None ,
241247):
242248 args = parse_arguments ()
243249 args .identifier = identifier
@@ -255,12 +261,19 @@ def run_batch(
255261 # Assuming we port-mapped already by running ray dashboard
256262 ray .init (
257263 "ray://localhost:10001" ,
258- runtime_env = {"working_dir" : "." , "pip" : ["causaltune" , "catboost" , "ray[tune]" ]},
264+ runtime_env = {
265+ "working_dir" : "." ,
266+ "pip" : ["causaltune" , "catboost" , "ray[tune]" , "flaml[blendsearch]" ],
267+ },
259268 namespace = RAY_NAMESPACE ,
260269 )
261270
262271 out_dir = run_experiment (
263- args , estimators = estimators , dataset_path = dataset_path , use_ray = use_ray
272+ args ,
273+ estimators = estimators ,
274+ dataset_path = dataset_path ,
275+ use_ray = use_ray ,
276+ propensity_automl_estimators = propensity_automl_estimators ,
264277 )
265278 return out_dir
266279
@@ -275,8 +288,8 @@ class TaskRunner:
275288 def __init__ (self ):
276289 self .futures = {}
277290
278- def remote_single_run (self , * args ):
279- ref = remote_single_run .remote (* args )
291+ def remote_single_run (self , * args , ** kwargs ):
292+ ref = remote_single_run .remote (* args , ** kwargs )
280293 self .futures [ref .hex ()] = ref
281294 return ref .hex ()
282295
@@ -310,6 +323,7 @@ def single_run(
310323 components_time_budget : int ,
311324 out_fn : str ,
312325 estimators : List [str ],
326+ propensity_automl_estimators : Optional [List [str ]] = None ,
313327 outcome_model : str = "auto" ,
314328 i_run : int = 1 ,
315329):
@@ -342,6 +356,7 @@ def single_run(
342356 store_all_estimators = True ,
343357 propensity_model = propensity_model ,
344358 outcome_model = outcome_model ,
359+ propensity_automl_estimators = propensity_automl_estimators ,
345360 use_ray = False ,
346361 )
347362
0 commit comments