Skip to content

Commit bffe827

Browse files
committed
fix: tests in pandas version > 2
1 parent cff1776 commit bffe827

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from causal_testing.specification.causal_specification import CausalSpecification
99
from causal_testing.testing.base_test_case import BaseTestCase
1010
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
11-
11+
import pandas as pd
1212

1313
@dataclass
1414
class SimulationResult:
@@ -18,6 +18,11 @@ class SimulationResult:
1818
fault: bool
1919
relationship: str
2020

21+
def to_dataframe(self) -> pd.DataFrame:
22+
"""Convert the simulation result data to a pandas DataFrame"""
23+
data_as_lists = {k: v if isinstance(v, list) else [v] for k,v in self.data.items()}
24+
return pd.DataFrame(data_as_lists)
25+
2126

2227
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
2328
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
@@ -87,14 +92,14 @@ def execute(
8792

8893
self.simulator.startup()
8994
test_result = self.simulator.run_with_config(candidate_test_case)
95+
test_result_df = test_result.to_dataframe()
9096
self.simulator.shutdown()
9197

9298
if custom_data_aggregator is not None:
9399
if data_collector.data is not None:
94100
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
95101
else:
96-
data_collector.data = data_collector.data.append(test_result.data, ignore_index=True)
97-
102+
data_collector.data = pd.concat([data_collector.data, test_result_df], ignore_index=True)
98103
if test_result.fault:
99104
print(
100105
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "

tests/surrogate_tests/test_causal_surrogate_assisted.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,4 +231,7 @@ def shutdown(self):
231231
pass
232232

233233
def data_double_aggregator(data, new_data):
234-
return data.append(new_data, ignore_index=True).append(new_data, ignore_index=True)
234+
"""Previously used data.append(new_data), however, pandas version >2 requires pd.concat() since append is now a private method.
235+
Converting new_data to a pd.DataFrame is required to use pd.concat(). """
236+
new_data = pd.DataFrame([new_data])
237+
return pd.concat([data, new_data, new_data], ignore_index=True)

0 commit comments

Comments
 (0)