Skip to content

Commit 7bf3f4c

Browse files
authored
Merge pull request #300 from CITCOM-project/gp-fix
Improved GP fitness to use NRMSE
2 parents bb2527b + f22af96 commit 7bf3f4c

File tree

2 files changed

+66
-11
lines changed

2 files changed

+66
-11
lines changed

causal_testing/estimation/genetic_programming_regression_fitter.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ def __init__(
149149
)
150150
self.sympy_conversions[name] = conversion
151151

152-
print(self.pset.mapping)
153152
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
154153
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)
155154

@@ -235,6 +234,8 @@ def simplify(self, expression: gp.PrimitiveTree) -> sympy.core.Expr:
235234
236235
:return: The simplified expression as a sympy Expr object.
237236
"""
237+
if isinstance(expression, str):
238+
expression = creator.Individual(gp.PrimitiveTree.from_string(expression, self.pset))
238239
return sympy.simplify(self._stringify_for_sympy(expression))
239240

240241
def repair(self, expression: gp.PrimitiveTree) -> gp.PrimitiveTree:
@@ -248,7 +249,7 @@ def repair(self, expression: gp.PrimitiveTree) -> gp.PrimitiveTree:
248249
"""
249250
eq = f"{self.outcome} ~ {' + '.join(str(x) for x in self.split(expression))}"
250251
try:
251-
# Create model, fit (run) it, give estimates from it]
252+
# Create model, fit (run) it, give estimates from it
252253
model = smf.ols(eq, self.df)
253254
res = model.fit()
254255

@@ -278,16 +279,25 @@ def fitness(self, expression: gp.PrimitiveTree) -> float:
278279
"""
279280
old_settings = np.seterr(all="raise")
280281
try:
281-
# Create model, fit (run) it, give estimates from it]
282+
if isinstance(expression, str):
283+
expression = creator.Individual(gp.PrimitiveTree.from_string(expression, self.pset))
284+
285+
# Create model, fit (run) it, give estimates from it
282286
func = gp.compile(expression, self.pset)
283-
y_estimates = pd.Series([func(**x) for _, x in self.df[self.features].iterrows()])
287+
y_estimates = pd.Series(
288+
[func(**x) for _, x in self.df[self.features].iterrows()],
289+
index=self.df.index,
290+
)
284291

285-
# Calc errors using an improved normalised mean squared
292+
# Calculate errors using the normalised root mean square error (nrmse),
293+
# which is normalised with respect to the range
286294
sqerrors = (self.df[self.outcome] - y_estimates) ** 2
287-
mean_squared = sqerrors.sum() / len(self.df)
288-
nmse = mean_squared / (self.df[self.outcome].sum() / len(self.df))
295+
nrmse = np.sqrt(sqerrors.sum() / len(self.df)) / (self.df[self.outcome].max() - self.df[self.outcome].min())
296+
297+
if pd.isnull(nrmse) or nrmse.real != nrmse:
298+
return (float("inf"),)
289299

290-
return (nmse,)
300+
return (nrmse,)
291301

292302
# Fitness value of infinite if error - not return 1
293303
except (
@@ -321,18 +331,29 @@ def make_offspring(self, population: list, num_offspring: int) -> list:
321331
offspring.append(child)
322332
return offspring
323333

324-
def run_gp(self, ngen: int, pop_size: int = 20, num_offspring: int = 10, seeds: list = None) -> gp.PrimitiveTree:
334+
# pylint: disable=too-many-arguments
335+
def run_gp(
336+
self,
337+
ngen: int,
338+
pop_size: int = 20,
339+
num_offspring: int = 10,
340+
seeds: list = None,
341+
repair: bool = True,
342+
) -> gp.PrimitiveTree:
325343
"""
326344
Execute Genetic Programming to find the best expression using a mu+lambda algorithm.
327345
328346
:param ngen: The maximum number of generations.
329347
:param pop_size: The population size.
330348
:param num_offspring: The number of new individuals per generation.
331349
:param seeds: Seed individuals for the initial population.
350+
:param repair: Whether to run the linear regression repair operator (defaults to True).
332351
333352
:return: The best candididate expression.
334353
"""
335-
population = [self.toolbox.repair(ind) for ind in self.toolbox.population(n=pop_size)]
354+
population = self.toolbox.population(n=pop_size)
355+
if repair:
356+
population = [self.toolbox.repair(ind) for ind in population]
336357
if seeds is not None:
337358
for seed in seeds:
338359
ind = creator.Individual(gp.PrimitiveTree.from_string(seed, self.pset))
@@ -348,7 +369,8 @@ def run_gp(self, ngen: int, pop_size: int = 20, num_offspring: int = 10, seeds:
348369
for _ in range(1, ngen + 1):
349370
# Vary the population
350371
offspring = self.make_offspring(population, num_offspring)
351-
offspring = [self.toolbox.repair(ind) for ind in offspring]
372+
if repair:
373+
offspring = [self.toolbox.repair(ind) for ind in offspring]
352374

353375
# Evaluate the individuals with an invalid fitness
354376
for ind in offspring:
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,43 @@
11
import unittest
22
import pandas as pd
3+
from operator import sub
34

45
from causal_testing.estimation.genetic_programming_regression_fitter import GP
56

67

8+
def root(x):
9+
return x**0.5
10+
11+
712
class TestGP(unittest.TestCase):
813
def test_init_invalid_fun_name(self):
914
with self.assertRaises(ValueError):
1015
GP(df=pd.DataFrame(), features=[], outcome="", max_order=2, sympy_conversions={"power_1": ""})
16+
17+
def test_simplify_string(self):
18+
gp = GP(
19+
df=None,
20+
features=["x1"],
21+
outcome=None,
22+
max_order=1,
23+
)
24+
self.assertEqual(str(gp.simplify("power_1(x1)")), "x1")
25+
26+
def test_fitness(self):
27+
gp = GP(
28+
df=pd.DataFrame({"x1": [1, 2, 3], "outcome": [2, 3, 4]}),
29+
features=["x1"],
30+
outcome="outcome",
31+
max_order=0,
32+
)
33+
self.assertEqual(gp.fitness("add(x1, 1)"), (0,))
34+
35+
def test_fitness_inf(self):
36+
gp = GP(
37+
df=pd.DataFrame({"x1": [1, 2, 3], "outcome": [2, 3, 4]}),
38+
features=["x1"],
39+
outcome="outcome",
40+
max_order=0,
41+
extra_operators=[(root, 1)],
42+
)
43+
self.assertEqual(gp.fitness("root(-1)"), (float("inf"),))

0 commit comments

Comments
 (0)