diff --git a/causal_testing/surrogate/causal_surrogate_assisted.py b/causal_testing/surrogate/causal_surrogate_assisted.py index c30b2086..c67d4e4b 100644 --- a/causal_testing/surrogate/causal_surrogate_assisted.py +++ b/causal_testing/surrogate/causal_surrogate_assisted.py @@ -9,6 +9,7 @@ from causal_testing.testing.base_test_case import BaseTestCase from causal_testing.testing.estimators import CubicSplineRegressionEstimator + @dataclass class SimulationResult: """Data class holding the data and result metadata of a simulation""" @@ -19,7 +20,7 @@ class SimulationResult: def to_dataframe(self) -> pd.DataFrame: """Convert the simulation result data to a pandas DataFrame""" - data_as_lists = {k: v if isinstance(v, list) else [v] for k,v in self.data.items()} + data_as_lists = {k: v if isinstance(v, list) else [v] for k, v in self.data.items()} return pd.DataFrame(data_as_lists) diff --git a/causal_testing/testing/estimators.py b/causal_testing/testing/estimators.py index 3920e1f2..ffe76387 100644 --- a/causal_testing/testing/estimators.py +++ b/causal_testing/testing/estimators.py @@ -823,7 +823,7 @@ def estimate_hazard_ratio(self): # IPCW step 4: Use these weights in a weighted analysis of the outcome model # Estimate the KM graph and IPCW hazard ratio using Cox regression. - cox_ph = CoxPHFitter() + cox_ph = CoxPHFitter(alpha=self.alpha) cox_ph.fit( df=preprocessed_data, duration_col="tout",