Skip to content

Commit 968fc89

Browse files
committed
Seeding gp power
1 parent 03f1bd8 commit 968fc89

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

causal_testing/estimation/gp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def __init__(
134134
"reciprocal": lambda x1: f"Pow({x1},-1)",
135135
} | sympy_conversions
136136

137-
for i in range(self.max_order):
137+
for i in range(self.max_order + 1):
138138
name = f"power_{i}"
139139
func, conversion = create_power_function(i)
140140
self.pset.addPrimitive(func, 1, name=name)
@@ -145,6 +145,7 @@ def __init__(
145145
)
146146
self.sympy_conversions[name] = conversion
147147

148+
print(self.pset.mapping)
148149
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
149150
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)
150151

tests/estimation_tests/test_linear_regression_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,10 @@ def test_gp_power(self):
303303
df["X"] = np.arange(10)
304304
df["Y"] = 2 * (df["X"] ** 2)
305305
linear_regression_estimator = LinearRegressionEstimator("X", 0, 1, set(), "Y", df.astype(float))
306-
linear_regression_estimator.gp_formula(seed=1, max_order=0)
306+
linear_regression_estimator.gp_formula(seed=1, max_order=2, seeds=["mul(2, power_2(X))"])
307307
self.assertEqual(
308308
linear_regression_estimator.formula,
309-
"Y ~ I(1.9999999999999999*X**2 - 1.0043240235058056e-116*X + 2.6645352591003757e-15) - 1",
309+
"Y ~ I(2*X**2) - 1",
310310
)
311311
ate, (ci_low, ci_high) = linear_regression_estimator.estimate_ate_calculated()
312312
self.assertEqual(round(ate[0], 2), -2.00)

0 commit comments

Comments
 (0)