From f77a45ccf95774936a49bf3308fcd5b1c1b5ec8f Mon Sep 17 00:00:00 2001 From: Farhad Allian Date: Fri, 26 Jul 2024 16:02:38 +0100 Subject: [PATCH] fix: tests for pandas version > 2 --- .../surrogate/causal_surrogate_assisted.py | 12 ++++++++---- .../test_causal_surrogate_assisted.py | 5 ++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/causal_testing/surrogate/causal_surrogate_assisted.py b/causal_testing/surrogate/causal_surrogate_assisted.py index 74f309be..c30b2086 100644 --- a/causal_testing/surrogate/causal_surrogate_assisted.py +++ b/causal_testing/surrogate/causal_surrogate_assisted.py @@ -3,13 +3,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Callable - +import pandas as pd from causal_testing.data_collection.data_collector import ObservationalDataCollector from causal_testing.specification.causal_specification import CausalSpecification from causal_testing.testing.base_test_case import BaseTestCase from causal_testing.testing.estimators import CubicSplineRegressionEstimator - @dataclass class SimulationResult: """Data class holding the data and result metadata of a simulation""" @@ -18,6 +17,11 @@ class SimulationResult: fault: bool relationship: str + def to_dataframe(self) -> pd.DataFrame: + """Convert the simulation result data to a pandas DataFrame""" + data_as_lists = {k: v if isinstance(v, list) else [v] for k,v in self.data.items()} + return pd.DataFrame(data_as_lists) + class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods """Class to be inherited with the search algorithm consisting of a search function and the fitness function of the @@ -87,14 +91,14 @@ def execute( self.simulator.startup() test_result = self.simulator.run_with_config(candidate_test_case) + test_result_df = test_result.to_dataframe() self.simulator.shutdown() if custom_data_aggregator is not None: if data_collector.data is not None: data_collector.data = custom_data_aggregator(data_collector.data, 
test_result.data) else: - data_collector.data = data_collector.data.append(test_result.data, ignore_index=True) - + data_collector.data = pd.concat([data_collector.data, test_result_df], ignore_index=True) if test_result.fault: print( f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with " diff --git a/tests/surrogate_tests/test_causal_surrogate_assisted.py b/tests/surrogate_tests/test_causal_surrogate_assisted.py index c5eb6e2c..54c93af1 100644 --- a/tests/surrogate_tests/test_causal_surrogate_assisted.py +++ b/tests/surrogate_tests/test_causal_surrogate_assisted.py @@ -231,4 +231,7 @@ def shutdown(self): pass def data_double_aggregator(data, new_data): - return data.append(new_data, ignore_index=True).append(new_data, ignore_index=True) \ No newline at end of file + """Previously this used data.append(new_data); however, pandas >= 2.0 requires pd.concat() because the public DataFrame.append method was removed. + Converting new_data to a pd.DataFrame is required to use pd.concat(). """ + new_data = pd.DataFrame([new_data]) + return pd.concat([data, new_data, new_data], ignore_index=True)