Skip to content

Commit 1e22537

Browse files
remaining linting + black for causal_surrogate_assisted.py
1 parent 7692084 commit 1e22537

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
"""Module containing classes to define and run causal surrogate assisted test cases"""
22

3+
from abc import ABC
4+
from dataclasses import dataclass
5+
from typing import Callable, Any
6+
37
from causal_testing.data_collection.data_collector import ObservationalDataCollector
48
from causal_testing.specification.causal_specification import CausalSpecification
59
from causal_testing.testing.base_test_case import BaseTestCase
610
from causal_testing.testing.estimators import Estimator, CubicSplineRegressionEstimator
711

8-
from dataclasses import dataclass
9-
from typing import Callable, Any
10-
from abc import ABC
11-
1212

1313
@dataclass
1414
class SimulationResult(ABC):
1515
"""Data class holding the data and result metadata of a simulation"""
16+
1617
data: dict
1718
fault: bool
1819
relationship: str
@@ -21,13 +22,14 @@ class SimulationResult(ABC):
2122
@dataclass
2223
class SearchFitnessFunction(ABC):
2324
"""Data class containing the Fitness function and related model"""
25+
2426
fitness_function: Any
2527
surrogate_model: CubicSplineRegressionEstimator
2628

2729

2830
class SearchAlgorithm:
2931
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
30-
space to be searched"""
32+
space to be searched"""
3133

3234
def generate_fitness_functions(self, surrogate_models: list[Estimator]) -> list[SearchFitnessFunction]:
3335
"""Generates the fitness function of the search space
@@ -43,7 +45,7 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
4345

4446
class Simulator:
4547
"""Class to be inherited with Simulator specific functions to start, shutdown and run the simulation with the give
46-
config file"""
48+
config file"""
4749

4850
def startup(self, **kwargs):
4951
"""Function that when run, initialises and opens the Simulator"""
@@ -62,35 +64,33 @@ class CausalSurrogateAssistedTestCase:
6264
"""A class representing a single causal surrogate assisted test case."""
6365

6466
def __init__(
65-
self,
66-
specification: CausalSpecification,
67-
search_algorithm: SearchAlgorithm,
68-
simulator: Simulator,
67+
self,
68+
specification: CausalSpecification,
69+
search_algorithm: SearchAlgorithm,
70+
simulator: Simulator,
6971
):
7072
self.specification = specification
7173
self.search_algorithm = search_algorithm
7274
self.simulator = simulator
7375

7476
def execute(
75-
self,
76-
data_collector: ObservationalDataCollector,
77-
max_executions: int = 200,
78-
custom_data_aggregator: Callable[[dict, dict], dict] = None,
77+
self,
78+
data_collector: ObservationalDataCollector,
79+
max_executions: int = 200,
80+
custom_data_aggregator: Callable[[dict, dict], dict] = None,
7981
):
80-
""" For this specific test case, collect the data, run the simulator, check for faults and return the result
81-
and collected data
82-
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
83-
:param max_executions: Maximum number of executions
84-
:param custom_data_aggregator:
85-
:return: tuple containing SimulationResult or str, execution number and collected data """
82+
"""For this specific test case, collect the data, run the simulator, check for faults and return the result
83+
and collected data
84+
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
85+
:param max_executions: Maximum number of executions
86+
:param custom_data_aggregator:
87+
:return: tuple containing SimulationResult or str, execution number and collected data"""
8688
data_collector.collect_data()
8789

8890
for i in range(max_executions):
8991
surrogate_models = self.generate_surrogates(self.specification, data_collector)
9092
fitness_functions = self.search_algorithm.generate_fitness_functions(surrogate_models)
91-
candidate_test_case, _fitness, surrogate = self.search_algorithm.search(
92-
fitness_functions, self.specification
93-
)
93+
candidate_test_case, _, surrogate = self.search_algorithm.search(fitness_functions, self.specification)
9494

9595
self.simulator.startup()
9696
test_result = self.simulator.run_with_config(candidate_test_case)
@@ -116,9 +116,9 @@ def execute(
116116
return "No fault found", i + 1, data_collector.data
117117

118118
def generate_surrogates(
119-
self, specification: CausalSpecification, data_collector: ObservationalDataCollector
119+
self, specification: CausalSpecification, data_collector: ObservationalDataCollector
120120
) -> list[SearchFitnessFunction]:
121-
""" Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
121+
"""Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
122122
:param specification: The Causal Specification (combination of Scenario and Causal Dag)
123123
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
124124
:return: A list of surrogate models

0 commit comments

Comments
 (0)