Skip to content

Commit 4c02210

Browse files
argparse test
1 parent 270f28d commit 4c02210

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

causal_testing/json_front/json_class.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def set_path(self, json_path: str, dag_path: str, data_path: str):
6767
self.dag_path = Path(dag_path)
6868
self.data_path = Path(data_path)
6969

70-
7170
def set_variables(self, inputs: dict, outputs: dict, metas: dict):
7271

7372
"""Populate the Causal Variables
@@ -79,7 +78,6 @@ def set_variables(self, inputs: dict, outputs: dict, metas: dict):
7978
self.outputs = [Output(i["name"], i["type"]) for i in outputs]
8079
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
8180

82-
8381
def setup(self):
8482
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
8583
self.modelling_scenario = Scenario(self.inputs + self.outputs + self.metas, None)
@@ -236,7 +234,7 @@ def setup_logger(log_path: str):
236234
setup_log.addHandler(file_handler)
237235

238236
@staticmethod
239-
def get_args() -> argparse.Namespace:
237+
def get_args(test_args=None) -> argparse.Namespace:
240238
"""Command-line arguments
241239
242240
:return: parsed command line arguments
@@ -259,7 +257,6 @@ def get_args() -> argparse.Namespace:
259257
help="Specify path to file containing runtime data",
260258
required=True,
261259
)
262-
parser.add_argument("--data_path", help="Specify path to file containing runtime data", required=True)
263260
parser.add_argument(
264261
"--dag_path",
265262
help="Specify path to file containing the DAG, normally a .dot file",
@@ -270,4 +267,4 @@ def get_args() -> argparse.Namespace:
270267
help="Specify path to file containing JSON tests, normally a .json file",
271268
required=True,
272269
)
273-
return parser.parse_args()
270+
return parser.parse_args(test_args)

causal_testing/testing/estimators.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,6 @@ def estimate_unit_ate(self) -> float:
309309

310310
return unit_effect * self.treatment_values - unit_effect * self.control_values, [ci_low, ci_high]
311311

312-
313312
def estimate_ate(self) -> tuple[float, list[float, float], float]:
314313
"""Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
315314
by changing the treatment variable from the control value to the treatment value.
@@ -531,7 +530,6 @@ def estimate_cates(self) -> pd.DataFrame:
531530
# Obtain CATES and confidence intervals
532531
conditional_ates = model.effect(effect_modifier_df, T0=self.control_values, T1=self.treatment_values).flatten()
533532
[ci_low, ci_high] = model.effect_interval(
534-
535533
effect_modifier_df, T0=self.control_values, T1=self.treatment_values, alpha=0.05
536534
)
537535

tests/json_front_tests/test_json_class.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,14 @@ def test_set_metas(self):
5353
self.assertEqual(ctf_meta[0].name, self.json_class.metas[0].name)
5454
self.assertEqual(ctf_meta[0].datatype, self.json_class.metas[0].datatype)
5555

56+
def test_argparse(self):
57+
args = self.json_class.get_args(["--data_path=data.csv", "--dag_path=dag.dot", "--json_path=tests.json"])
58+
self.assertTrue(args.data_path)
59+
self.assertTrue(args.dag_path)
60+
self.assertTrue(args.json_path)
61+
5662
def tearDown(self) -> None:
5763
remove_temp_dir_if_existent()
5864

59-
6065
def populate_example():
6166
pass

0 commit comments

Comments
 (0)