Skip to content

Commit cd3b41d

Browse files
committed
Merge remote-tracking branch 'origin/surrogateassisted' into surrogateassisted
# Conflicts: # tests/surrogate_tests/test_causal_surrogate_assisted.py # tests/testing_tests/test_estimators.py
2 parents 162e28d + 04b377d commit cd3b41d

File tree

8 files changed

+498
-149
lines changed

8 files changed

+498
-149
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,9 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
531531
if scenario is not None:
532532
minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)
533533

534+
if len(minimal_adjustment_sets) == 0:
535+
return set()
536+
534537
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
535538
return minimal_adjustment_set
536539

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 55 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,61 @@
1+
"""Module containing classes to define and run causal surrogate assisted test cases"""
2+
3+
from abc import ABC, abstractmethod
4+
from dataclasses import dataclass
5+
from typing import Callable
6+
17
from causal_testing.data_collection.data_collector import ObservationalDataCollector
28
from causal_testing.specification.causal_specification import CausalSpecification
39
from causal_testing.testing.base_test_case import BaseTestCase
4-
from causal_testing.testing.estimators import Estimator, PolynomialRegressionEstimator
5-
6-
from dataclasses import dataclass
7-
from typing import Callable, Any
8-
from abc import ABC
10+
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
911

1012

1113
@dataclass
12-
class SimulationResult(ABC):
14+
class SimulationResult:
15+
"""Data class holding the data and result metadata of a simulation"""
16+
1317
data: dict
1418
fault: bool
1519
relationship: str
1620

1721

18-
@dataclass
19-
class SearchFitnessFunction(ABC):
20-
fitness_function: Any
21-
surrogate_model: PolynomialRegressionEstimator
22-
22+
class SearchAlgorithm(ABC):
23+
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
24+
space to be searched"""
2325

24-
class SearchAlgorithm:
25-
def generate_fitness_functions(self, surrogate_models: list[Estimator]) -> list[SearchFitnessFunction]:
26-
pass
26+
@abstractmethod
27+
def search(
28+
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
29+
) -> list:
30+
"""Function which implements a search routine which searches for the optimal fitness value for the specified
31+
scenario
32+
:param surrogate_models: The surrogate models to be searched
33+
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""
2734

28-
def search(self, fitness_functions: list[SearchFitnessFunction], specification: CausalSpecification) -> list:
29-
pass
3035

36+
class Simulator(ABC):
37+
"""Class to be inherited with Simulator specific functions to start, shutdown and run the simulation with the given
38+
config file"""
3139

32-
class Simulator:
40+
@abstractmethod
3341
def startup(self, **kwargs):
34-
pass
42+
"""Function that when run, initialises and opens the Simulator"""
3543

44+
@abstractmethod
3645
def shutdown(self, **kwargs):
37-
pass
46+
"""Function to safely exit and shutdown the Simulator"""
3847

39-
def run_with_config(self, configuration) -> SimulationResult:
40-
pass
48+
@abstractmethod
49+
def run_with_config(self, configuration: dict) -> SimulationResult:
50+
"""Run the simulator with the given configuration and return the results in the structure of a
51+
SimulationResult
52+
:param configuration: The configuration required to initialise the Simulation
53+
:return: Simulation results in the structure of the SimulationResult data class"""
4154

4255

4356
class CausalSurrogateAssistedTestCase:
57+
"""A class representing a single causal surrogate assisted test case."""
58+
4459
def __init__(
4560
self,
4661
specification: CausalSpecification,
@@ -57,27 +72,33 @@ def execute(
5772
max_executions: int = 200,
5873
custom_data_aggregator: Callable[[dict, dict], dict] = None,
5974
):
75+
"""For this specific test case, a search algorithm is used to find the most contradictory point in the input
76+
space which is, therefore, most likely to indicate incorrect behaviour. This candidate test case is run against
77+
the simulator, checked for faults and the result returned with collected data
78+
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
79+
:param max_executions: Maximum number of simulator executions before exiting the search
80+
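:param custom_data_aggregator: Optional function that combines the data collected so far with the data of a simulation result and returns the merged data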
81+
:return: tuple containing SimulationResult or str, execution number and collected data"""
6082
data_collector.collect_data()
6183

6284
for i in range(max_executions):
6385
surrogate_models = self.generate_surrogates(self.specification, data_collector)
64-
fitness_functions = self.search_algorithm.generate_fitness_functions(surrogate_models)
65-
candidate_test_case, _fitness, surrogate = self.search_algorithm.search(
66-
fitness_functions, self.specification
67-
)
86+
candidate_test_case, _, surrogate = self.search_algorithm.search(surrogate_models, self.specification)
6887

6988
self.simulator.startup()
7089
test_result = self.simulator.run_with_config(candidate_test_case)
7190
self.simulator.shutdown()
7291

7392
if custom_data_aggregator is not None:
74-
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
93+
if data_collector.data is not None:
94+
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
7595
else:
7696
data_collector.data = data_collector.data.append(test_result.data, ignore_index=True)
7797

7898
if test_result.fault:
7999
print(
80-
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with expected {surrogate.expected_relationship}."
100+
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "
101+
f"expected {surrogate.expected_relationship}."
81102
)
82103
test_result.relationship = (
83104
f"{surrogate.treatment} -> {surrogate.outcome} expected {surrogate.expected_relationship}"
@@ -89,7 +110,12 @@ def execute(
89110

90111
def generate_surrogates(
91112
self, specification: CausalSpecification, data_collector: ObservationalDataCollector
92-
) -> list[SearchFitnessFunction]:
113+
) -> list[CubicSplineRegressionEstimator]:
114+
"""Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
115+
:param specification: The Causal Specification (combination of Scenario and Causal Dag)
116+
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
117+
:return: A list of surrogate models
118+
"""
93119
surrogate_models = []
94120

95121
for u, v in specification.causal_dag.graph.edges:
@@ -101,7 +127,7 @@ def generate_surrogates(
101127

102128
minimal_adjustment_set = specification.causal_dag.identification(base_test_case, specification.scenario)
103129

104-
surrogate = PolynomialRegressionEstimator(
130+
surrogate = CubicSplineRegressionEstimator(
105131
u,
106132
0,
107133
0,
Lines changed: 64 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1-
from causal_testing.specification.causal_specification import CausalSpecification
2-
from causal_testing.testing.estimators import Estimator, PolynomialRegressionEstimator
3-
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm, SearchFitnessFunction
1+
"""Module containing implementation of search algorithm for surrogate search """
2+
# pylint: disable=cell-var-from-loop
3+
# Fitness functions are required to be iteratively defined, including all variables within.
44

5-
from pygad import GA
65
from operator import itemgetter
6+
from pygad import GA
7+
8+
from causal_testing.specification.causal_specification import CausalSpecification
9+
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
10+
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm
711

812

913
class GeneticSearchAlgorithm(SearchAlgorithm):
14+
"""Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models."""
15+
1016
def __init__(self, delta=0.05, config: dict = None) -> None:
1117
super().__init__()
1218

@@ -15,91 +21,93 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
1521
self.contradiction_functions = {
1622
"positive": lambda x: -1 * x,
1723
"negative": lambda x: x,
18-
"no_effect": lambda x: abs(x),
24+
"no_effect": abs,
1925
"some_effect": lambda x: abs(1 / x),
2026
}
2127

22-
def generate_fitness_functions(
23-
self, surrogate_models: list[PolynomialRegressionEstimator]
24-
) -> list[SearchFitnessFunction]:
25-
fitness_functions = []
28+
def search(
29+
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
30+
) -> list:
31+
solutions = []
2632

2733
for surrogate in surrogate_models:
2834
contradiction_function = self.contradiction_functions[surrogate.expected_relationship]
2935

30-
# The returned fitness function after including required variables into the function's scope
31-
def fitness_function(_ga, solution, idx):
36+
# The GA fitness function after including required variables into the function's scope
37+
# Unused arguments are required for pygad's fitness function signature
38+
def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
3239
surrogate.control_value = solution[0] - self.delta
3340
surrogate.treatment_value = solution[0] + self.delta
3441

35-
adjustment_dict = dict()
42+
adjustment_dict = {}
3643
for i, adjustment in enumerate(surrogate.adjustment_set):
3744
adjustment_dict[adjustment] = solution[i + 1]
3845

3946
ate = surrogate.estimate_ate_calculated(adjustment_dict)
4047

4148
return contradiction_function(ate)
4249

43-
search_fitness_function = SearchFitnessFunction(fitness_function, surrogate)
44-
45-
fitness_functions.append(search_fitness_function)
46-
47-
return fitness_functions
48-
49-
def search(self, fitness_functions: list[SearchFitnessFunction], specification: CausalSpecification) -> list:
50-
solutions = []
51-
52-
for fitness_function in fitness_functions:
53-
var_space = dict()
54-
var_space[fitness_function.surrogate_model.treatment] = dict()
55-
for adj in fitness_function.surrogate_model.adjustment_set:
56-
var_space[adj] = dict()
57-
58-
for relationship in list(specification.scenario.constraints):
59-
rel_split = str(relationship).split(" ")
60-
61-
if rel_split[1] == ">=":
62-
var_space[rel_split[0]]["low"] = int(rel_split[2])
63-
elif rel_split[1] == "<=":
64-
var_space[rel_split[0]]["high"] = int(rel_split[2])
65-
66-
gene_space = []
67-
gene_space.append(var_space[fitness_function.surrogate_model.treatment])
68-
for adj in fitness_function.surrogate_model.adjustment_set:
69-
gene_space.append(var_space[adj])
70-
71-
gene_types = []
72-
gene_types.append(specification.scenario.variables.get(fitness_function.surrogate_model.treatment).datatype)
73-
for adj in fitness_function.surrogate_model.adjustment_set:
74-
gene_types.append(specification.scenario.variables.get(adj).datatype)
50+
gene_types, gene_space = self.create_gene_types(surrogate, specification)
7551

7652
ga = GA(
7753
num_generations=200,
7854
num_parents_mating=4,
79-
fitness_func=fitness_function.fitness_function,
55+
fitness_func=fitness_function,
8056
sol_per_pop=10,
81-
num_genes=1 + len(fitness_function.surrogate_model.adjustment_set),
57+
num_genes=1 + len(surrogate.adjustment_set),
8258
gene_space=gene_space,
8359
gene_type=gene_types,
8460
)
8561

8662
if self.config is not None:
8763
for k, v in self.config.items():
8864
if k == "gene_space":
89-
raise Exception(
90-
"Gene space should not be set through config. This is generated from the causal specification"
65+
raise ValueError(
66+
"Gene space should not be set through config. This is generated from the causal "
67+
"specification"
9168
)
9269
setattr(ga, k, v)
9370

9471
ga.run()
95-
solution, fitness, _idx = ga.best_solution()
72+
solution, fitness, _ = ga.best_solution()
9673

97-
solution_dict = dict()
98-
solution_dict[fitness_function.surrogate_model.treatment] = solution[0]
99-
for idx, adj in enumerate(fitness_function.surrogate_model.adjustment_set):
74+
solution_dict = {}
75+
solution_dict[surrogate.treatment] = solution[0]
76+
for idx, adj in enumerate(surrogate.adjustment_set):
10077
solution_dict[adj] = solution[idx + 1]
101-
solutions.append((solution_dict, fitness, fitness_function.surrogate_model))
78+
solutions.append((solution_dict, fitness, surrogate))
79+
80+
return max(solutions, key=itemgetter(1)) # This can be done better with fitness normalisation between edges
81+
82+
@staticmethod
83+
def create_gene_types(
84+
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
85+
) -> tuple[list, list]:
86+
"""Generate the gene_types and gene_space for a given surrogate model and specification
87+
:param surrogate_model: Instance of a CubicSplineRegressionEstimator
88+
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""
89+
90+
var_space = {}
91+
var_space[surrogate_model.treatment] = {}
92+
for adj in surrogate_model.adjustment_set:
93+
var_space[adj] = {}
94+
95+
for relationship in list(specification.scenario.constraints):
96+
rel_split = str(relationship).split(" ")
97+
98+
if rel_split[0] in var_space:
99+
if rel_split[1] == ">=":
100+
var_space[rel_split[0]]["low"] = int(rel_split[2])
101+
elif rel_split[1] == "<=":
102+
var_space[rel_split[0]]["high"] = int(rel_split[2])
103+
104+
gene_space = []
105+
gene_space.append(var_space[surrogate_model.treatment])
106+
for adj in surrogate_model.adjustment_set:
107+
gene_space.append(var_space[adj])
102108

103-
return max(
104-
solutions, key=itemgetter(1)
105-
) # TODO This can be done better with fitness normalisation between edges
109+
gene_types = []
110+
gene_types.append(specification.scenario.variables.get(surrogate_model.treatment).datatype)
111+
for adj in surrogate_model.adjustment_set:
112+
gene_types.append(specification.scenario.variables.get(adj).datatype)
113+
return gene_types, gene_space

0 commit comments

Comments
 (0)