Skip to content

Commit 46216aa

Browse files
committed
Fix: final linting
1 parent 42d133f commit 46216aa

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

dafni/main_dafni.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
import os
1+
"""
2+
3+
Entrypoint script to run the causal testing framework on DAFNI
4+
5+
"""
6+
27
from pathlib import Path
38
import argparse
49
import json
@@ -14,7 +19,6 @@ class ValidationError(Exception):
1419
"""
1520
Custom class to capture validation errors in this script
1621
"""
17-
pass
1822

1923

2024
def get_args(test_args=None) -> argparse.Namespace:
@@ -24,20 +28,22 @@ def get_args(test_args=None) -> argparse.Namespace:
2428
:returns:
2529
- argparse.Namespace - A Namsespace consisting of the arguments to this script
2630
"""
27-
parser = argparse.ArgumentParser(description="A script for running the causal testing famework on DAFNI.")
31+
parser = argparse.ArgumentParser(description="A script for running the CTF on DAFNI.")
2832

2933
parser.add_argument(
3034
"--data_path", required=True,
3135
help="Path to the input runtime data (.csv)", nargs="+")
3236

3337
parser.add_argument('--tests_path', required=True,
34-
help='Path to the input configuration file containing the causal tests (.json)')
38+
help='Input configuration file path '
39+
'containing the causal tests (.json)')
3540

3641
parser.add_argument('--variables_path', required=True,
37-
help='Path to the input configuration file containing the predefined variables (.json)')
42+
help='Input configuration file path '
43+
'containing the predefined variables (.json)')
3844

3945
parser.add_argument("--dag_path", required=True,
40-
help="Path to the input file containing a valid DAG (.dot). "
46+
help="Input configuration file path containing a valid DAG (.dot). "
4147
"Note: this must be supplied if the --tests argument isn't provided.")
4248

4349
parser.add_argument('--output_path', required=False, help='Path to the output directory.')
@@ -81,7 +87,7 @@ def get_args(test_args=None) -> argparse.Namespace:
8187
return args
8288

8389

84-
def read_variables(variables_path: Path) -> dict:
90+
def read_variables(variables_path: Path) -> FileNotFoundError | dict:
8591
"""
8692
Function to read the variables.json file specified by the user
8793
:param variables_path: A Path object of the user-specified file path
@@ -90,44 +96,43 @@ def read_variables(variables_path: Path) -> dict:
9096
"""
9197
if not variables_path.exists() or variables_path.is_dir():
9298

93-
raise FileNotFoundError
94-
9599
print(f"JSON file not found at the specified location: {variables_path}")
96100

97-
else:
101+
raise FileNotFoundError
98102

99-
with variables_path.open('r') as file:
103+
with variables_path.open('r') as file:
100104

101-
inputs = json.load(file)
105+
inputs = json.load(file)
102106

103-
return inputs
107+
return inputs
104108

105109

106110
def validate_variables(data_dict: dict) -> tuple:
107111
"""
108112
Function to validate the variables defined in the causal tests
109113
:param data_dict: A dictionary consisting of the pre-defined variables for the causal tests
110114
:returns:
111-
- tuple - Tuple consisting of the inputs, outputs and constraints to pass into the modelling scenario
115+
- Tuple containing the inputs, outputs and constraints to pass into the modelling scenario
112116
"""
113117
if data_dict["variables"]:
114118

115119
variables = data_dict["variables"]
116120

117-
inputs = [Input(variable["name"], eval(variable["datatype"])) for variable in variables if
121+
inputs = [Input(variable["name"], eval(variable["datatype"]))
122+
for variable in variables if
118123
variable["typestring"] == "Input"]
119124

120-
outputs = [Output(variable["name"], eval(variable["datatype"])) for variable in variables if
125+
outputs = [Output(variable["name"], eval(variable["datatype"]))
126+
for variable in variables if
121127
variable["typestring"] == "Output"]
122128

123129
constraints = set()
124130

125-
for variable, _inputs in zip(variables, inputs):
131+
for variable, input_var in zip(variables, inputs):
126132

127133
if "constraint" in variable:
128134

129-
constraints.add(_inputs.z3 == variable["constraint"])
130-
135+
constraints.add(input_var.z3 == variable["constraint"])
131136
else:
132137

133138
raise ValidationError("Cannot find the variables defined by the causal tests.")
@@ -180,7 +185,8 @@ def main():
180185
json_utility.setup(scenario=modelling_scenario, data=data_frame)
181186

182187
# Step 7: Run the causal tests
183-
test_outcomes = json_utility.run_json_tests(effects=expected_outcome_effects, mutates={}, estimators=estimators,
188+
test_outcomes = json_utility.run_json_tests(effects=expected_outcome_effects,
189+
mutates={}, estimators=estimators,
184190
f_flag=args.f)
185191

186192
# Step 8: Update, print and save the final outputs
@@ -196,7 +202,7 @@ def main():
196202
test["result"].pop("control_value")
197203

198204

199-
with open(args.output_path, "w") as f:
205+
with open(args.output_path, "w", encoding="utf-8") as f:
200206

201207
print(json.dumps(test_outcomes, indent=2), file=f)
202208

@@ -208,7 +214,8 @@ def main():
208214

209215
else:
210216

211-
print(f"Execution successful. Output file saved at {Path(args.output_path).parent.resolve()}")
217+
print(f"Execution successful. "
218+
f"Output file saved at {Path(args.output_path).parent.resolve()}")
212219

213220

214221
if __name__ == "__main__":

0 commit comments

Comments
 (0)