Skip to content

Commit 6dd1012

Browse files
Merge branch 'main' into python-311-compatible
# Conflicts: # causal_testing/testing/estimators.py # tests/testing_tests/test_causal_test_case.py # tests/testing_tests/test_causal_test_suite.py # tests/testing_tests/test_estimators.py
2 parents f30ca8f + 0d9b1d7 commit 6dd1012

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+2325
-515
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
name: Continuous Integration Tests Draft PR (pytest)
2+
# This duplicate ci workflow is required so the badge in the README.md is not effected by draft PRs
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
8+
jobs:
9+
build:
10+
if: github.event.pull_request.draft == true
11+
name: Ex1 (${{ matrix.python-version }}, ${{ matrix.os }})
12+
runs-on: ${{ matrix.os }}
13+
strategy:
14+
matrix:
15+
os: ["ubuntu-latest", "windows-latest", "macos-latest"]
16+
python-version: ["3.9"]
17+
steps:
18+
- uses: actions/checkout@v2
19+
- name: Set up Python using Miniconda
20+
uses: conda-incubator/setup-miniconda@v2
21+
with:
22+
auto-update-conda: true
23+
python-version: ${{ matrix.python-version }}
24+
- name: Install package and dependencies
25+
run: |
26+
python --version
27+
pip install -e .
28+
pip install -e .[test]
29+
pip install pytest pytest-cov
30+
shell: bash -l {0}
31+
- name: Test with pytest
32+
run: |
33+
pytest --cov=causal_testing --cov-report=xml
34+
shell: bash -l {0}
35+
- name: "Upload coverage to Codecov"
36+
uses: codecov/codecov-action@v2
37+
with:
38+
fail_ci_if_error: true
39+
token: ${{ secrets.CODECOV_TOKEN }}

.github/workflows/ci-tests.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@ on:
44
pull_request:
55
branches:
66
- main
7+
types:
8+
- opened
9+
- synchronize
10+
- reopened
11+
- ready_for_review
712

813
jobs:
914
build:
15+
if: github.event.pull_request.draft == false # Filter out draft PRs
1016
name: Ex1 (${{ matrix.python-version }}, ${{ matrix.os }})
1117
runs-on: ${{ matrix.os }}
1218
strategy:

.github/workflows/figshare.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
name: Release to Figshare
2+
on:
3+
workflow_dispatch:
4+
release:
5+
types: [published]
6+
jobs:
7+
upload:
8+
runs-on: ubuntu-latest
9+
env:
10+
ARCHIVE_NAME: ${{ github.event.repository.name }}-${{ github.event.release.tag_name }}
11+
steps:
12+
- name: prepare-data-folder
13+
run : mkdir 'data'
14+
- name: download-archive
15+
run: |
16+
curl -sL "${{ github.event.release.zipball_url }}" > "$ARCHIVE_NAME".zip
17+
curl -sL "${{ github.event.release.tarball_url }}" > "$ARCHIVE_NAME".tar.gz
18+
- name: move-archive
19+
run: |
20+
mv "$ARCHIVE_NAME".zip data/
21+
mv "$ARCHIVE_NAME".tar.gz data/
22+
- name: upload-to-figshare
23+
uses: figshare/[email protected]
24+
with:
25+
FIGSHARE_TOKEN: ${{ secrets.FIGSHARE_TOKEN }}
26+
FIGSHARE_ENDPOINT: 'https://api.figshare.com/v2'
27+
FIGSHARE_ARTICLE_ID: 24427516
28+
DATA_DIR: 'data'

README.md

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1-
# Causal Testing Framework: A Causal Inference-Driven Software Testing Framework
1+
# Causal Testing Framework
2+
### A Causal Inference-Driven Software Testing Framework
23

3-
![example workflow](https://github.com/CITCOM-project/CausalTestingFramework/actions/workflows/ci-tests.yaml/badge.svg) [![codecov](https://codecov.io/gh/CITCOM-project/CausalTestingFramework/branch/main/graph/badge.svg?token=04ijFVrb4a)](https://codecov.io/gh/CITCOM-project/CausalTestingFramework) [![Documentation Status](https://readthedocs.org/projects/causal-testing-framework/badge/?version=latest)](https://causal-testing-framework.readthedocs.io/en/latest/?badge=latest)
4+
5+
[![Project Status: Active – The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
6+
![example workflow](https://github.com/CITCOM-project/CausalTestingFramework/actions/workflows/ci-tests.yaml/badge.svg)
7+
[![codecov](https://codecov.io/gh/CITCOM-project/CausalTestingFramework/branch/main/graph/badge.svg?token=04ijFVrb4a)](https://codecov.io/gh/CITCOM-project/CausalTestingFramework)
8+
[![Documentation Status](https://readthedocs.org/projects/causal-testing-framework/badge/?version=latest)](https://causal-testing-framework.readthedocs.io/en/latest/?badge=latest)
9+
![Dynamic TOML Badge](https://img.shields.io/badge/dynamic/toml?url=https%3A%2F%2Fraw.githubusercontent.com%2FCITCOM-project%2FCausalTestingFramework%2Fmain%2Fpyproject.toml&query=%24.project%5B'requires-python'%5D&label=python)
10+
![PyPI - Version](https://img.shields.io/pypi/v/causal-testing-framework)
11+
[![DOI](https://t.ly/FCT1B)](https://orda.shef.ac.uk/articles/software/CITCOM_Software_Release/24427516)
12+
![GitHub License](https://img.shields.io/github/license/CITCOM-project/CausalTestingFramework)
413

514
Causal testing is a causal inference-driven framework for functional black-box testing. This framework utilises
615
graphical causal inference (CI) techniques for the specification and functional testing of software from a black-box
@@ -12,10 +21,9 @@ system-under-test that is expected to cause a change to some output(s).
1221

1322
![Causal Testing Workflow](images/workflow.png)
1423

15-
1624
## Installation
1725

18-
See the readthedocs site for [installation
26+
See the Read the Docs site for [installation
1927
instructions](https://causal-testing-framework.readthedocs.io/en/latest/installation.html).
2028

2129
## Documentation

causal_testing/json_front/json_class.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
301301
"""Create the necessary inputs for a single test case
302302
:param causal_test_case: The concrete test case to be executed
303303
:param test: Single JSON test definition stored in a mapping (dict)
304-
:param conditions: A list of conditions which should be applied to the
305-
data. Conditions should be in the query format detailed at
306-
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
307304
:returns:
308305
- estimation_model - Estimator instance for the test being run
309306
"""
@@ -323,11 +320,13 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
323320
minimal_adjustment_set = minimal_adjustment_set - {causal_test_case.treatment_variable}
324321
estimator_kwargs["adjustment_set"] = minimal_adjustment_set
325322

323+
estimator_kwargs["query"] = test["query"] if "query" in test else ""
326324
estimator_kwargs["treatment"] = causal_test_case.treatment_variable.name
327325
estimator_kwargs["treatment_value"] = causal_test_case.treatment_value
328326
estimator_kwargs["control_value"] = causal_test_case.control_value
329327
estimator_kwargs["outcome"] = causal_test_case.outcome_variable.name
330328
estimator_kwargs["effect_modifiers"] = causal_test_case.effect_modifier_configuration
329+
estimator_kwargs["df"] = self.data_collector.collect_data()
331330
estimator_kwargs["alpha"] = test["alpha"] if "alpha" in test else 0.05
332331

333332
estimation_model = test["estimator"](**estimator_kwargs)

causal_testing/specification/causal_dag.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def close_separator(
125125

126126

127127
class CausalDAG(nx.DiGraph):
128-
129128
"""A causal DAG is a directed acyclic graph in which nodes represent random variables and edges represent causality
130129
between a pair of random variables. We implement a CausalDAG as a networkx DiGraph with an additional check that
131130
ensures it is acyclic. A CausalDAG must be specified as a dot file.
@@ -500,11 +499,20 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
500499
return True
501500
return any((self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)))
502501

503-
def identification(self, base_test_case: BaseTestCase):
502+
@staticmethod
503+
def remove_hidden_adjustment_sets(minimal_adjustment_sets: list[str], scenario: Scenario):
504+
"""Remove variables labelled as hidden from adjustment set(s)
505+
:param minimal_adjustment_sets: list of minimal adjustment set(s) to have hidden variables removed from
506+
:param scenario: The modelling scenario which informs the variables that are hidden
507+
"""
508+
return [adj for adj in minimal_adjustment_sets if all(not scenario.variables.get(x).hidden for x in adj)]
509+
510+
def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None):
504511
"""Identify and return the minimum adjustment set
505512
506513
:param base_test_case: A base test case instance containing the outcome_variable and the
507514
treatment_variable required for identification.
515+
:param scenario: The modelling scenario relating to the tests
508516
:return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
509517
estimate as opposed to a purely associational estimate.
510518
"""
@@ -520,6 +528,12 @@ def identification(self, base_test_case: BaseTestCase):
520528
else:
521529
raise ValueError("Causal effect should be 'total' or 'direct'")
522530

531+
if scenario is not None:
532+
minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)
533+
534+
if len(minimal_adjustment_sets) == 0:
535+
return set()
536+
523537
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
524538
return minimal_adjustment_set
525539

causal_testing/surrogate/__init__.py

Whitespace-only changes.
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""Module containing classes to define and run causal surrogate assisted test cases"""
2+
3+
from abc import ABC, abstractmethod
4+
from dataclasses import dataclass
5+
from typing import Callable
6+
7+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
8+
from causal_testing.specification.causal_specification import CausalSpecification
9+
from causal_testing.testing.base_test_case import BaseTestCase
10+
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
11+
12+
13+
@dataclass
14+
class SimulationResult:
15+
"""Data class holding the data and result metadata of a simulation"""
16+
17+
data: dict
18+
fault: bool
19+
relationship: str
20+
21+
22+
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
23+
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
24+
space to be searched"""
25+
26+
@abstractmethod
27+
def search(
28+
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
29+
) -> list:
30+
"""Function which implements a search routine which searches for the optimal fitness value for the specified
31+
scenario
32+
:param surrogate_models: The surrogate models to be searched
33+
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""
34+
35+
36+
class Simulator(ABC):
37+
"""Class to be inherited with Simulator specific functions to start, shutdown and run the simulation with the give
38+
config file"""
39+
40+
@abstractmethod
41+
def startup(self, **kwargs):
42+
"""Function that when run, initialises and opens the Simulator"""
43+
44+
@abstractmethod
45+
def shutdown(self, **kwargs):
46+
"""Function to safely exit and shutdown the Simulator"""
47+
48+
@abstractmethod
49+
def run_with_config(self, configuration: dict) -> SimulationResult:
50+
"""Run the simulator with the given configuration and return the results in the structure of a
51+
SimulationResult
52+
:param configuration: The configuration required to initialise the Simulation
53+
:return: Simulation results in the structure of the SimulationResult data class"""
54+
55+
56+
class CausalSurrogateAssistedTestCase:
57+
"""A class representing a single causal surrogate assisted test case."""
58+
59+
def __init__(
60+
self,
61+
specification: CausalSpecification,
62+
search_algorithm: SearchAlgorithm,
63+
simulator: Simulator,
64+
):
65+
self.specification = specification
66+
self.search_algorithm = search_algorithm
67+
self.simulator = simulator
68+
69+
def execute(
70+
self,
71+
data_collector: ObservationalDataCollector,
72+
max_executions: int = 200,
73+
custom_data_aggregator: Callable[[dict, dict], dict] = None,
74+
):
75+
"""For this specific test case, a search algorithm is used to find the most contradictory point in the input
76+
space which is, therefore, most likely to indicate incorrect behaviour. This cadidate test case is run against
77+
the simulator, checked for faults and the result returned with collected data
78+
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
79+
:param max_executions: Maximum number of simulator executions before exiting the search
80+
:param custom_data_aggregator:
81+
:return: tuple containing SimulationResult or str, execution number and collected data"""
82+
data_collector.collect_data()
83+
84+
for i in range(max_executions):
85+
surrogate_models = self.generate_surrogates(self.specification, data_collector)
86+
candidate_test_case, _, surrogate = self.search_algorithm.search(surrogate_models, self.specification)
87+
88+
self.simulator.startup()
89+
test_result = self.simulator.run_with_config(candidate_test_case)
90+
self.simulator.shutdown()
91+
92+
if custom_data_aggregator is not None:
93+
if data_collector.data is not None:
94+
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
95+
else:
96+
data_collector.data = data_collector.data.append(test_result.data, ignore_index=True)
97+
98+
if test_result.fault:
99+
print(
100+
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "
101+
f"expected {surrogate.expected_relationship}."
102+
)
103+
test_result.relationship = (
104+
f"{surrogate.treatment} -> {surrogate.outcome} expected {surrogate.expected_relationship}"
105+
)
106+
return test_result, i + 1, data_collector.data
107+
108+
print("No fault found")
109+
return "No fault found", i + 1, data_collector.data
110+
111+
def generate_surrogates(
112+
self, specification: CausalSpecification, data_collector: ObservationalDataCollector
113+
) -> list[CubicSplineRegressionEstimator]:
114+
"""Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
115+
:param specification: The Causal Specification (combination of Scenario and Causal Dag)
116+
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
117+
:return: A list of surrogate models
118+
"""
119+
surrogate_models = []
120+
121+
for u, v in specification.causal_dag.graph.edges:
122+
edge_metadata = specification.causal_dag.graph.adj[u][v]
123+
if "included" in edge_metadata:
124+
from_var = specification.scenario.variables.get(u)
125+
to_var = specification.scenario.variables.get(v)
126+
base_test_case = BaseTestCase(from_var, to_var)
127+
128+
minimal_adjustment_set = specification.causal_dag.identification(base_test_case, specification.scenario)
129+
130+
surrogate = CubicSplineRegressionEstimator(
131+
u,
132+
0,
133+
0,
134+
minimal_adjustment_set,
135+
v,
136+
4,
137+
df=data_collector.data,
138+
expected_relationship=edge_metadata["expected"],
139+
)
140+
surrogate_models.append(surrogate)
141+
142+
return surrogate_models

0 commit comments

Comments
 (0)