1
1
"""Module containing classes to define and run causal surrogate assisted test cases"""
2
2
3
+ from abc import ABC
4
+ from dataclasses import dataclass
5
+ from typing import Callable , Any
6
+
3
7
from causal_testing .data_collection .data_collector import ObservationalDataCollector
4
8
from causal_testing .specification .causal_specification import CausalSpecification
5
9
from causal_testing .testing .base_test_case import BaseTestCase
6
10
from causal_testing .testing .estimators import Estimator , CubicSplineRegressionEstimator
7
11
8
- from dataclasses import dataclass
9
- from typing import Callable , Any
10
- from abc import ABC
11
-
12
12
13
13
@dataclass
14
14
class SimulationResult (ABC ):
15
15
"""Data class holding the data and result metadata of a simulation"""
16
+
16
17
data : dict
17
18
fault : bool
18
19
relationship : str
@@ -21,13 +22,14 @@ class SimulationResult(ABC):
21
22
@dataclass
22
23
class SearchFitnessFunction (ABC ):
23
24
"""Data class containing the Fitness function and related model"""
25
+
24
26
fitness_function : Any
25
27
surrogate_model : CubicSplineRegressionEstimator
26
28
27
29
28
30
class SearchAlgorithm :
29
31
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
30
- space to be searched"""
32
+ space to be searched"""
31
33
32
34
def generate_fitness_functions (self , surrogate_models : list [Estimator ]) -> list [SearchFitnessFunction ]:
33
35
"""Generates the fitness function of the search space
@@ -43,7 +45,7 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
43
45
44
46
class Simulator :
45
47
"""Class to be inherited with Simulator specific functions to start, shutdown and run the simulation with the give
46
- config file"""
48
+ config file"""
47
49
48
50
def startup (self , ** kwargs ):
49
51
"""Function that when run, initialises and opens the Simulator"""
@@ -62,35 +64,33 @@ class CausalSurrogateAssistedTestCase:
62
64
"""A class representing a single causal surrogate assisted test case."""
63
65
64
66
def __init__ (
65
- self ,
66
- specification : CausalSpecification ,
67
- search_algorithm : SearchAlgorithm ,
68
- simulator : Simulator ,
67
+ self ,
68
+ specification : CausalSpecification ,
69
+ search_algorithm : SearchAlgorithm ,
70
+ simulator : Simulator ,
69
71
):
70
72
self .specification = specification
71
73
self .search_algorithm = search_algorithm
72
74
self .simulator = simulator
73
75
74
76
def execute (
75
- self ,
76
- data_collector : ObservationalDataCollector ,
77
- max_executions : int = 200 ,
78
- custom_data_aggregator : Callable [[dict , dict ], dict ] = None ,
77
+ self ,
78
+ data_collector : ObservationalDataCollector ,
79
+ max_executions : int = 200 ,
80
+ custom_data_aggregator : Callable [[dict , dict ], dict ] = None ,
79
81
):
80
- """ For this specific test case, collect the data, run the simulator, check for faults and return the result
81
- and collected data
82
- :param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
83
- :param max_executions: Maximum number of executions
84
- :param custom_data_aggregator:
85
- :return: tuple containing SimulationResult or str, execution number and collected data """
82
+ """For this specific test case, collect the data, run the simulator, check for faults and return the result
83
+ and collected data
84
+ :param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
85
+ :param max_executions: Maximum number of executions
86
+ :param custom_data_aggregator:
87
+ :return: tuple containing SimulationResult or str, execution number and collected data"""
86
88
data_collector .collect_data ()
87
89
88
90
for i in range (max_executions ):
89
91
surrogate_models = self .generate_surrogates (self .specification , data_collector )
90
92
fitness_functions = self .search_algorithm .generate_fitness_functions (surrogate_models )
91
- candidate_test_case , _fitness , surrogate = self .search_algorithm .search (
92
- fitness_functions , self .specification
93
- )
93
+ candidate_test_case , _ , surrogate = self .search_algorithm .search (fitness_functions , self .specification )
94
94
95
95
self .simulator .startup ()
96
96
test_result = self .simulator .run_with_config (candidate_test_case )
@@ -116,9 +116,9 @@ def execute(
116
116
return "No fault found" , i + 1 , data_collector .data
117
117
118
118
def generate_surrogates (
119
- self , specification : CausalSpecification , data_collector : ObservationalDataCollector
119
+ self , specification : CausalSpecification , data_collector : ObservationalDataCollector
120
120
) -> list [SearchFitnessFunction ]:
121
- """ Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
121
+ """Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
122
122
:param specification: The Causal Specification (combination of Scenario and Causal Dag)
123
123
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
124
124
:return: A list of surrogate models
0 commit comments