Skip to content

Commit b5d2c0a

Browse files
Merge branch 'main' into update-readme
# Conflicts:
#   causal_testing/json_front/json_class.py
#   causal_testing/testing/estimators.py
2 parents 2b4839d + 6764fb3 commit b5d2c0a

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

41 files changed

+996
-761
lines changed

.github/ISSUE_TEMPLATE/bug_report.md

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,38 @@
1+
---
2+
name: Bug report
3+
about: Create a report to help us improve
4+
title: ''
5+
labels: ''
6+
assignees: ''
7+
8+
---
9+
10+
**Describe the bug**
11+
A clear and concise description of what the bug is.
12+
13+
**To Reproduce**
14+
Steps to reproduce the behavior:
15+
1. Go to '...'
16+
2. Click on '....'
17+
3. Scroll down to '....'
18+
4. See error
19+
20+
**Expected behavior**
21+
A clear and concise description of what you expected to happen.
22+
23+
**Screenshots**
24+
If applicable, add screenshots to help explain your problem.
25+
26+
**Desktop (please complete the following information):**
27+
- OS: [e.g. iOS]
28+
- Browser [e.g. chrome, safari]
29+
- Version [e.g. 22]
30+
31+
**Smartphone (please complete the following information):**
32+
- Device: [e.g. iPhone6]
33+
- OS: [e.g. iOS8.1]
34+
- Browser [e.g. stock browser, safari]
35+
- Version [e.g. 22]
36+
37+
**Additional context**
38+
Add any other context about the problem here.
Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,20 @@
1+
---
2+
name: Feature request
3+
about: Suggest an idea for this project
4+
title: ''
5+
labels: ''
6+
assignees: ''
7+
8+
---
9+
10+
**Is your feature request related to a problem? Please describe.**
11+
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12+
13+
**Describe the solution you'd like**
14+
A clear and concise description of what you want to happen.
15+
16+
**Describe alternatives you've considered**
17+
A clear and concise description of any alternative solutions or features you've considered.
18+
19+
**Additional context**
20+
Add any other context or screenshots about the feature request here.

.github/workflows/publish-to-pypi.yaml

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,10 @@
11
name: Publish python PyPI
22

3+
on:
4+
push:
5+
tags:
6+
- v*
7+
38
jobs:
49
build-release:
510
name: Build and publish PyPI
@@ -17,6 +22,8 @@ jobs:
1722
pip3 install .
1823
pip3 install .[pypi]
1924
pip3 install build
25+
pip3 install setuptools --upgrade
26+
pip3 install setuptools_scm
2027
- name: Build Package
2128
run: |
2229
python -m build --no-isolation

.github/workflows/publish-to-test-pypi.yaml

Lines changed: 0 additions & 32 deletions
This file was deleted.

.pylintrc

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,8 @@ disable=raw-checker-failed,
152152
useless-suppression,
153153
deprecated-pragma,
154154
use-symbolic-message-instead,
155+
logging-fstring-interpolation,
156+
import-error,
155157

156158
# Enable the message, report, category or checker with the given id(s). You can
157159
# either give multiple identifier separated by comma (,) or put this option
@@ -239,7 +241,9 @@ good-names=i,
239241
j,
240242
k,
241243
ex,
244+
df,
242245
Run,
246+
z3,
243247
_
244248

245249
# Good variable names regexes, separated by a comma. If names match any regex,

causal_testing/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,13 @@
1+
"""
2+
This is the CausalTestingFramework Module
3+
It contains 5 subpackages:
4+
data_collection
5+
generation
6+
json_front
7+
specification
8+
testing
9+
"""
10+
111
import logging
212

313
logger = logging.getLogger(__name__)

causal_testing/data_collection/data_collector.py

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
"""This module contains the DataCollector abstract class, as well as its concrete extensions: ExperimentalDataCollector
2+
and ObservationalDataCollector"""
3+
14
import logging
25
from abc import ABC, abstractmethod
36
from enum import Enum
@@ -35,11 +38,15 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
3538
"""
3639

3740
# Check positivity
38-
scenario_variables = set(self.scenario.variables)
41+
scenario_variables = set(self.scenario.variables) - {x.name for x in self.scenario.hidden_variables()}
3942

40-
if check_pos and not scenario_variables.issubset(data.columns):
43+
if check_pos and not (scenario_variables - {x.name for x in self.scenario.hidden_variables()}).issubset(
44+
set(data.columns)
45+
):
4146
missing_variables = scenario_variables - set(data.columns)
42-
raise IndexError(f"Positivity violation: missing data for variables {missing_variables}.")
47+
raise IndexError(
48+
f"Missing columns: missing data for variables {missing_variables}. Should they be marked as hidden?"
49+
)
4350

4451
# For each row, does it satisfy the constraints?
4552
solver = z3.Solver()
@@ -54,6 +61,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
5461
self.scenario.variables[var].z3
5562
== self.scenario.variables[var].z3_val(self.scenario.variables[var].z3, row[var])
5663
for var in self.scenario.variables
64+
if var in row
5765
]
5866
for c in model:
5967
solver.assert_and_track(c, f"model: {c}")
@@ -73,10 +81,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
7381
size_diff = len(data) - len(satisfying_data)
7482
if size_diff > 0:
7583
logger.warning(
76-
"Discarded %s/%s values due to constraint violations.\n" "For example%s",
77-
size_diff,
78-
len(data),
79-
unsat_core,
84+
f"Discarded {size_diff}/{len(data)} values due to constraint violations.\n For example {unsat_core}",
8085
)
8186
return satisfying_data
8287

@@ -122,22 +127,23 @@ def run_system_with_input_configuration(self, input_configuration: dict) -> pd.D
122127

123128

124129
class ObservationalDataCollector(DataCollector):
125-
"""A data collector that extracts data that is relevant to the specified scenario from a csv of execution data."""
130+
"""A data collector that extracts data that is relevant to the specified scenario from a dataframe of execution
131+
data."""
126132

127-
def __init__(self, scenario: Scenario, csv_path: str):
133+
def __init__(self, scenario: Scenario, data: pd.DataFrame):
128134
super().__init__(scenario)
129-
self.csv_path = csv_path
135+
self.data = data
130136

131137
def collect_data(self, **kwargs) -> pd.DataFrame:
132-
"""Read a csv containing execution data for the system-under-test into a pandas dataframe and filter to remove
138+
"""Read a pandas dataframe and filter to remove
133139
any data which is invalid for the scenario-under-test.
134140
135141
Data is invalid if it does not meet the constraints outlined in the scenario-under-test (Scenario).
136142
137143
:return: A pandas dataframe containing execution data that is valid for the scenario-under-test.
138144
"""
139145

140-
execution_data_df = pd.read_csv(self.csv_path, **kwargs)
146+
execution_data_df = self.data
141147
for meta in self.scenario.metas():
142148
meta.populate(execution_data_df)
143149
scenario_execution_data_df = self.filter_valid_data(execution_data_df)

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 52 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,21 @@
1+
"""This module contains the class AbstractCausalTestCase, which generates concrete test cases"""
2+
import itertools
13
import logging
4+
from enum import Enum
5+
from typing import Iterable
26

37
import lhsmdu
48
import pandas as pd
59
import z3
610
from scipy import stats
7-
import itertools
11+
812

913
from causal_testing.specification.scenario import Scenario
1014
from causal_testing.specification.variable import Variable
1115
from causal_testing.testing.causal_test_case import CausalTestCase
1216
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
1317
from causal_testing.testing.base_test_case import BaseTestCase
1418

15-
from enum import Enum
1619

1720
logger = logging.getLogger(__name__)
1821

@@ -25,6 +28,7 @@ class AbstractCausalTestCase:
2528
"""
2629

2730
def __init__(
31+
# pylint: disable=too-many-arguments
2832
self,
2933
scenario: Scenario,
3034
intervention_constraints: set[z3.ExprRef],
@@ -60,7 +64,9 @@ def __str__(self):
6064
)
6165
return f"When we apply intervention {self.intervention_constraints}, {outcome_string}"
6266

63-
def datapath(self):
67+
def datapath(self) -> str:
68+
"""Create and return the sanitised data path"""
69+
6470
def sanitise(string):
6571
return "".join([x for x in string if x.isalnum()])
6672

@@ -72,7 +78,11 @@ def sanitise(string):
7278
)
7379

7480
def _generate_concrete_tests(
75-
self, sample_size: int, rct: bool = False, seed: int = 0
81+
# pylint: disable=too-many-locals
82+
self,
83+
sample_size: int,
84+
rct: bool = False,
85+
seed: int = 0,
7686
) -> tuple[list[CausalTestCase], pd.DataFrame]:
7787
"""Generates a list of `num` concrete test cases.
7888
@@ -101,25 +111,7 @@ def _generate_concrete_tests(
101111
samples[var.name] = lhsmdu.inverseTransformSample(var.distribution, samples[var.name])
102112

103113
for index, row in samples.iterrows():
104-
optimizer = z3.Optimize()
105-
for c in self.scenario.constraints:
106-
optimizer.assert_and_track(c, str(c))
107-
for c in self.intervention_constraints:
108-
optimizer.assert_and_track(c, str(c))
109-
110-
for v in run_columns:
111-
optimizer.add_soft(
112-
self.scenario.variables[v].z3
113-
== self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
114-
)
115-
116-
if optimizer.check() == z3.unsat:
117-
logger.warning(
118-
"Satisfiability of test case was unsat.\n" "Constraints \n %s \n Unsat core %s",
119-
optimizer,
120-
optimizer.unsat_core(),
121-
)
122-
model = optimizer.model()
114+
model = self._optimizer_model(run_columns, row)
123115

124116
base_test_case = BaseTestCase(
125117
treatment_variable=self.treatment_variable,
@@ -146,7 +138,7 @@ def _generate_concrete_tests(
146138
+ f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
147139
)
148140

149-
if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
141+
if not any((vars(t) == vars(concrete_test) for t in concrete_tests)):
150142
concrete_tests.append(concrete_test)
151143
# Control run
152144
control_run = {
@@ -164,6 +156,7 @@ def _generate_concrete_tests(
164156
return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
165157

166158
def generate_concrete_tests(
159+
# pylint: disable=too-many-arguments, too-many-locals
167160
self,
168161
sample_size: int,
169162
target_ks_score: float = None,
@@ -197,12 +190,12 @@ def generate_concrete_tests(
197190

198191
pre_break = False
199192
for i in range(hard_max):
200-
concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
201-
for t_ in concrete_tests_:
202-
if not any([vars(t_) == vars(t) for t in concrete_tests]):
203-
concrete_tests.append(t_)
204-
runs = pd.concat([runs, runs_])
205-
assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
193+
concrete_tests_temp, runs_temp = self._generate_concrete_tests(sample_size, rct, seed + i)
194+
for test in concrete_tests_temp:
195+
if not any((vars(test) == vars(t) for t in concrete_tests)):
196+
concrete_tests.append(test)
197+
runs = pd.concat([runs, runs_temp])
198+
assert concrete_tests_temp not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
206199

207200
control_configs = pd.DataFrame([{test.treatment_variable: test.control_value} for test in concrete_tests])
208201
ks_stats = {
@@ -230,7 +223,7 @@ def generate_concrete_tests(
230223
control_values = [test.control_value for test in concrete_tests]
231224
treatment_values = [test.treatment_value for test in concrete_tests]
232225

233-
if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
226+
if self.treatment_variable.datatype is bool and {(True, False), (False, True)}.issubset(
234227
set(zip(control_values, treatment_values))
235228
):
236229
pre_break = True
@@ -244,7 +237,7 @@ def generate_concrete_tests(
244237
).issubset(zip(control_values, treatment_values)):
245238
pre_break = True
246239
break
247-
elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
240+
if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
248241
pre_break = True
249242
break
250243

@@ -256,3 +249,30 @@ def generate_concrete_tests(
256249
len(concrete_tests),
257250
)
258251
return concrete_tests, runs
252+
253+
def _optimizer_model(self, run_columns: Iterable[str], row: pd.core.series) -> z3.Optimize:
254+
"""
255+
:param run_columns: A sorted list of Variable names from the scenario variables
256+
:param row: A pandas Series containing a row from the Samples dataframe
257+
:return: z3 optimize model with constraints tracked and soft constraints added
258+
:rtype: z3.Optimize
259+
"""
260+
optimizer = z3.Optimize()
261+
for c in self.scenario.constraints:
262+
optimizer.assert_and_track(c, str(c))
263+
for c in self.intervention_constraints:
264+
optimizer.assert_and_track(c, str(c))
265+
266+
for v in run_columns:
267+
optimizer.add_soft(
268+
self.scenario.variables[v].z3
269+
== self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
270+
)
271+
272+
if optimizer.check() == z3.unsat:
273+
logger.warning(
274+
f"Satisfiability of test case was unsat.\n"
275+
f"Constraints \n {optimizer} \n Unsat core {optimizer.unsat_core()}",
276+
)
277+
model = optimizer.model()
278+
return model

0 commit comments

Comments
 (0)