Skip to content

Commit 7692084

Browse files
Add all docstrings to causal_surrogate_assisted.py
1 parent 051861e commit 7692084

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def shutdown(self, **kwargs):
5454
def run_with_config(self, configuration) -> SimulationResult:
5555
"""Run the simulator with the given configuration and return the results in the structure of a
5656
SimulationResult
57-
:param configuration:
57+
:param configuration: the configuration required to initialise the Simulation
5858
:return: Simulation results in the structure of the SimulationResult data class"""
5959

6060

@@ -77,6 +77,12 @@ def execute(
7777
max_executions: int = 200,
7878
custom_data_aggregator: Callable[[dict, dict], dict] = None,
7979
):
80+
""" For this specific test case, collect the data, run the simulator, check for faults and return the result
81+
and collected data
82+
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
83+
:param max_executions: Maximum number of executions
84+
:param custom_data_aggregator:
85+
:return: tuple containing SimulationResult or str, execution number and collected data """
8086
data_collector.collect_data()
8187

8288
for i in range(max_executions):
@@ -112,6 +118,11 @@ def execute(
112118
def generate_surrogates(
113119
self, specification: CausalSpecification, data_collector: ObservationalDataCollector
114120
) -> list[SearchFitnessFunction]:
121+
""" Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
122+
:param specification: The Causal Specification (combination of Scenario and Causal Dag)
123+
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
124+
:return: A list of surrogate models
125+
"""
115126
surrogate_models = []
116127

117128
for u, v in specification.causal_dag.graph.edges:

0 commit comments

Comments
 (0)