Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions causal_testing/surrogate/causal_surrogate_assisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable

import pandas as pd
from causal_testing.data_collection.data_collector import ObservationalDataCollector
from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.testing.base_test_case import BaseTestCase
from causal_testing.testing.estimators import CubicSplineRegressionEstimator


@dataclass
class SimulationResult:
"""Data class holding the data and result metadata of a simulation"""
Expand All @@ -18,6 +17,11 @@ class SimulationResult:
fault: bool
relationship: str

def to_dataframe(self) -> pd.DataFrame:
"""Convert the simulation result data to a pandas DataFrame"""
data_as_lists = {k: v if isinstance(v, list) else [v] for k,v in self.data.items()}
return pd.DataFrame(data_as_lists)


class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
Expand Down Expand Up @@ -87,14 +91,14 @@ def execute(

self.simulator.startup()
test_result = self.simulator.run_with_config(candidate_test_case)
test_result_df = test_result.to_dataframe()
self.simulator.shutdown()

if custom_data_aggregator is not None:
if data_collector.data is not None:
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
else:
data_collector.data = data_collector.data.append(test_result.data, ignore_index=True)

data_collector.data = pd.concat([data_collector.data, test_result_df], ignore_index=True)
if test_result.fault:
print(
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "
Expand Down
5 changes: 4 additions & 1 deletion tests/surrogate_tests/test_causal_surrogate_assisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,7 @@ def shutdown(self):
pass

def data_double_aggregator(data, new_data):
return data.append(new_data, ignore_index=True).append(new_data, ignore_index=True)
"""Previously used data.append(new_data), however, pandas version >2 requires pd.concat() since append is now a private method.
Converting new_data to a pd.DataFrame is required to use pd.concat(). """
new_data = pd.DataFrame([new_data])
return pd.concat([data, new_data, new_data], ignore_index=True)
Loading