Skip to content

Commit 1e75b70

Browse files
authored
Merge pull request #287 from CITCOM-project/update-surrogate-tests
Update Unit Tests for Pandas versions > 2
2 parents bea5ef6 + f77a45c commit 1e75b70

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
55
from typing import Callable
6-
6+
import pandas as pd
77
from causal_testing.data_collection.data_collector import ObservationalDataCollector
88
from causal_testing.specification.causal_specification import CausalSpecification
99
from causal_testing.testing.base_test_case import BaseTestCase
1010
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
1111

12-
1312
@dataclass
1413
class SimulationResult:
1514
"""Data class holding the data and result metadata of a simulation"""
@@ -18,6 +17,11 @@ class SimulationResult:
1817
fault: bool
1918
relationship: str
2019

20+
def to_dataframe(self) -> pd.DataFrame:
21+
"""Convert the simulation result data to a pandas DataFrame"""
22+
data_as_lists = {k: v if isinstance(v, list) else [v] for k,v in self.data.items()}
23+
return pd.DataFrame(data_as_lists)
24+
2125

2226
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
2327
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
@@ -87,14 +91,14 @@ def execute(
8791

8892
self.simulator.startup()
8993
test_result = self.simulator.run_with_config(candidate_test_case)
94+
test_result_df = test_result.to_dataframe()
9095
self.simulator.shutdown()
9196

9297
if custom_data_aggregator is not None:
9398
if data_collector.data is not None:
9499
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
95100
else:
96-
data_collector.data = data_collector.data.append(test_result.data, ignore_index=True)
97-
101+
data_collector.data = pd.concat([data_collector.data, test_result_df], ignore_index=True)
98102
if test_result.fault:
99103
print(
100104
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "

tests/surrogate_tests/test_causal_surrogate_assisted.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,4 +231,7 @@ def shutdown(self):
231231
pass
232232

233233
def data_double_aggregator(data, new_data):
234-
return data.append(new_data, ignore_index=True).append(new_data, ignore_index=True)
234+
"""Previously used data.append(new_data), however, pandas version >2 requires pd.concat() since append is now a private method.
235+
Converting new_data to a pd.DataFrame is required to use pd.concat(). """
236+
new_data = pd.DataFrame([new_data])
237+
return pd.concat([data, new_data, new_data], ignore_index=True)

0 commit comments

Comments
 (0)