Skip to content

Commit 1b43da6

Browse files
committed
Merge branch 'main' of github.com:CITCOM-project/CausalTestingFramework into json_read_multiple_files
2 parents 555e830 + 505cca6 commit 1b43da6

File tree

5 files changed

+25
-11
lines changed

5 files changed

+25
-11
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,15 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
3838
"""
3939

4040
# Check positivity
41-
scenario_variables = set(self.scenario.variables)
41+
scenario_variables = set(self.scenario.variables) - {x.name for x in self.scenario.hidden_variables()}
4242

43-
if check_pos and not scenario_variables.issubset(data.columns):
43+
if check_pos and not (scenario_variables - {x.name for x in self.scenario.hidden_variables()}).issubset(
44+
set(data.columns)
45+
):
4446
missing_variables = scenario_variables - set(data.columns)
45-
raise IndexError(f"Positivity violation: missing data for variables {missing_variables}.")
47+
raise IndexError(
48+
f"Missing columns: missing data for variables {missing_variables}. Should they be marked as hidden?"
49+
)
4650

4751
# For each row, does it satisfy the constraints?
4852
solver = z3.Solver()
@@ -57,6 +61,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
5761
self.scenario.variables[var].z3
5862
== self.scenario.variables[var].z3_val(self.scenario.variables[var].z3, row[var])
5963
for var in self.scenario.variables
64+
if var in row
6065
]
6166
for c in model:
6267
solver.assert_and_track(c, f"model: {c}")

causal_testing/json_front/json_class.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,6 @@ class CausalVariables:
300300
metas: list[Meta]
301301

302302
def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
303-
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
304-
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
305-
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
303+
self.inputs = [Input(**i) for i in inputs]
304+
self.outputs = [Output(**o) for o in outputs]
305+
self.metas = [Meta(**m) for m in metas] if metas else []

causal_testing/specification/scenario.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,14 @@ def metas(self) -> set[Meta]:
142142
"""
143143
return self.variables_of_type(Meta)
144144

145+
def hidden_variables(self) -> set[Variable]:
146+
"""Get the set of hidden variables
147+
148+
:return The variables marked as hidden.
149+
:rtype: {Variable}
150+
"""
151+
return {v for v in self.variables.values() if v.hidden}
152+
145153
def add_variable(self, v: Variable) -> None:
146154
"""Add variable to variables attribute
147155
:param v: Variable to be added

causal_testing/specification/variable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,19 @@ class Variable(ABC):
6060
:attr name:
6161
:attr datatype:
6262
:attr distribution:
63-
63+
:attr hidden:
6464
"""
6565

6666
name: str
6767
datatype: T
6868
distribution: rv_generic
6969

70-
def __init__(self, name: str, datatype: T, distribution: rv_generic = None):
70+
def __init__(self, name: str, datatype: T, distribution: rv_generic = None, hidden: bool = False):
7171
self.name = name
7272
self.datatype = datatype
7373
self.z3 = z3_types(datatype)(name)
7474
self.distribution = distribution
75+
self.hidden = hidden
7576

7677
def __repr__(self):
7778
return f"{self.typestring()}: {self.name}::{self.datatype.__name__}"

tests/json_front_tests/test_json_class.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def setUp(self) -> None:
3131
self.data_path = [str(test_data_dir_path / data_file_name)]
3232
self.json_class = JsonUtility("logs.log")
3333
self.example_distribution = scipy.stats.uniform(1, 10)
34-
self.input_dict_list = [{"name": "test_input", "type": float, "distribution": self.example_distribution}]
35-
self.output_dict_list = [{"name": "test_output", "type": float}]
36-
self.meta_dict_list = [{"name": "test_meta", "type": float, "populate": populate_example}]
34+
self.input_dict_list = [{"name": "test_input", "datatype": float, "distribution": self.example_distribution}]
35+
self.output_dict_list = [{"name": "test_output", "datatype": float}]
36+
self.meta_dict_list = [{"name": "test_meta", "datatype": float, "populate": populate_example}]
3737
self.json_class.set_variables(self.input_dict_list, self.output_dict_list, None)
3838
self.json_class.set_paths(self.json_path, self.dag_path, self.data_path)
3939

0 commit comments

Comments
 (0)