Skip to content

Commit 790c12b

Browse files
committed
GP in
1 parent 2b7042d commit 790c12b

File tree

4 files changed

+350
-1
lines changed

4 files changed

+350
-1
lines changed

causal_testing/gp/gp.py

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
import random
2+
import warnings
3+
import patsy
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import pandas as pd
8+
import statsmodels.formula.api as smf
9+
import statsmodels
10+
import sympy
11+
import copy
12+
13+
from functools import partial
14+
from deap import algorithms, base, creator, tools, gp
15+
16+
from numpy import negative, exp, power, log, sin, cos, tan, sinh, cosh, tanh
17+
from inspect import isclass
18+
19+
from operator import add, mul
20+
21+
22+
def root(x):
    """Return the square root of ``x``, computed as ``numpy.power(x, 0.5)``."""
    exponent = 0.5
    return power(x, exponent)
24+
25+
26+
def square(x):
    """Return ``x`` raised to the second power via ``numpy.power``."""
    exponent = 2
    return power(x, exponent)
28+
29+
30+
def cube(x):
    """Return ``x`` raised to the third power via ``numpy.power``."""
    exponent = 3
    return power(x, exponent)
32+
33+
34+
def fourth_power(x):
    """Return ``x`` raised to the fourth power via ``numpy.power``."""
    exponent = 4
    return power(x, exponent)
36+
37+
38+
def reciprocal(x):
    """
    Return ``1 / x`` for a scalar or array-like ``x``.

    The exponent is deliberately the float ``-1.0``: numpy refuses to raise
    integers to negative *integer* powers (it raises ``ValueError``), which
    would make this primitive unusable on integer-typed data. With a float
    exponent the result is always floating point; float inputs give exactly
    the same values as before.

    :param x: The value(s) to invert.
    :returns: The element-wise reciprocal of ``x`` as floating point.
    """
    return power(x, -1.0)
40+
41+
42+
def mutInsert(individual, pset):
    """
    Insert a new primitive above a randomly chosen node of *individual*.

    This mirrors ``gp.mutInsert`` from DEAP, but relies on ``isclass`` imported
    from ``inspect`` in this module so that the "isclass not defined" bug is
    avoided.

    The subtree rooted at the chosen node is kept and becomes one of the
    children of the newly inserted primitive, so this is a genuine insertion
    rather than a replacement. When the new primitive takes several arguments
    of a compatible type, the slot that receives the original subtree is picked
    at random; every other slot is filled with a randomly chosen terminal.

    :param individual: The normal or typed tree to be mutated (modified in place).
    :param pset: The primitive set to draw the new primitive and terminals from.
    :returns: A tuple of one tree.
    """
    pivot = random.randrange(len(individual))
    anchor = individual[pivot]
    subtree_span = individual.searchSubtree(pivot)

    # The new parent must both return the anchor's type and accept it as an
    # argument, otherwise the existing subtree could not be re-attached.
    candidates = [prim for prim in pset.primitives[anchor.ret] if anchor.ret in prim.args]
    if not candidates:
        return (individual,)

    parent = random.choice(candidates)
    slot = random.choice([idx for idx, arg in enumerate(parent.args) if arg == anchor.ret])

    # Build the replacement branch left to right: parent first, then its
    # arguments, splicing the original subtree into the chosen slot.
    branch = [parent]
    for idx, arg_type in enumerate(parent.args):
        if idx == slot:
            branch.extend(individual[subtree_span])
            continue
        filler = random.choice(pset.terminals[arg_type])
        # Terminals stored as classes are instantiated on use.
        branch.append(filler() if isclass(filler) else filler)

    individual[subtree_span] = branch
    return (individual,)
84+
85+
86+
class GP:
    """
    Genetic-programming search for an expression over ``features`` that
    predicts ``outcome`` in a data frame, built on DEAP.

    Candidate expressions are DEAP primitive trees. Each candidate is
    "repaired" by refitting the coefficients of its top-level additive terms
    with ordinary least squares, and scored with a normalised mean squared
    error (lower is better).
    """

    def __init__(
        self,
        df: pd.DataFrame,
        features: list,
        outcome: str,
        extra_operators: list = None,
        sympy_conversions: dict = None,
        seed=0,
    ):
        """
        Configure the primitive set and the DEAP toolbox.

        Note that this seeds Python's global ``random`` module with ``seed``.

        :param df: The data to fit against; must contain ``features`` and ``outcome``.
        :param features: Column names available to expressions as variables.
        :param outcome: Name of the column to be predicted.
        :param extra_operators: ``(callable, arity)`` pairs registered in addition to
                                add, mul and reciprocal. Defaults to log and reciprocal.
        :param sympy_conversions: Maps a primitive name to a callable that formats its
                                  (already stringified) arguments as a sympy expression.
                                  Entries override the built-in add/mul/reciprocal ones.
        :param seed: Seed for the global random number generator.
        """
        random.seed(seed)
        self.df = df
        self.features = features
        self.outcome = outcome
        self.seed = seed
        self.pset = gp.PrimitiveSet("MAIN", len(self.features))
        self.pset.renameArguments(**{f"ARG{i}": f for i, f in enumerate(self.features)})

        # NOTE(review): reciprocal appears in both the standard and the default
        # extra operators, so by default it is registered twice. Left untouched
        # because changing the primitive list would change seeded results.
        standard_operators = [(add, 2), (mul, 2), (reciprocal, 1)]
        if extra_operators is None:
            extra_operators = [(log, 1), (reciprocal, 1)]
        for operator, num_args in standard_operators + extra_operators:
            self.pset.addPrimitive(operator, num_args)
        if sympy_conversions is None:
            sympy_conversions = {}
        self.sympy_conversions = {
            "mul": lambda *args_: "Mul({},{})".format(*args_),
            "add": lambda *args_: "Add({},{})".format(*args_),
            "reciprocal": lambda *args_: "Pow({},-1)".format(*args_),
        } | sympy_conversions

        # ``creator`` is module-global state: only create the classes once so that
        # building a second GP instance does not overwrite them.
        if not hasattr(creator, "FitnessMin"):
            creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
        if not hasattr(creator, "Individual"):
            creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)

        self.toolbox = base.Toolbox()
        self.toolbox.register("expr", gp.genHalfAndHalf, pset=self.pset, min_=1, max_=2)
        self.toolbox.register("individual", tools.initIterate, creator.Individual, self.toolbox.expr)
        self.toolbox.register("population", tools.initRepeat, list, self.toolbox.individual)
        self.toolbox.register("compile", gp.compile, pset=self.pset)
        self.toolbox.register("evaluate", self.evalSymbReg)
        self.toolbox.register("repair", self.repair)
        self.toolbox.register("select", tools.selBest)
        self.toolbox.register("mate", gp.cxOnePoint)
        self.toolbox.register("expr_mut", gp.genFull, min_=0, max_=2)
        self.toolbox.register("mutate", self.mutate, expr=self.toolbox.expr_mut)
        # Bound tree growth: crossover/mutation results are limited on height + 1 <= 17.
        self.toolbox.decorate("mate", gp.staticLimit(key=lambda x: x.height + 1, max_value=17))
        self.toolbox.decorate("mutate", gp.staticLimit(key=lambda x: x.height + 1, max_value=17))

    def split(self, individual):
        """
        Break a tree into its top-level additive terms.

        If the root is ``add`` (or ``sub``), both children are split recursively;
        anything else is a single term.

        :param individual: The tree to decompose.
        :returns: A list of trees whose sum corresponds to ``individual``.
        """
        if len(individual) <= 1 or individual[0].name not in ("add", "sub"):
            return [individual]

        # The root is binary: the first child occupies ``first_child`` and the
        # second child is everything after it.
        first_child = individual.searchSubtree(1)
        terms = []
        terms.extend(self.split(creator.Individual(gp.PrimitiveTree(individual[first_child]))))
        terms.extend(self.split(creator.Individual(gp.PrimitiveTree(individual[first_child.stop :]))))
        return terms

    def _convert_inverse_prim(self, prim, args):
        """
        Format one tree node as a sympy-parsable string.

        Uses the formatter registered for ``prim.name`` in ``self.sympy_conversions``
        (e.g. ``add`` -> ``Add(a,b)``, ``reciprocal`` -> ``Pow(a,-1)``) and falls back
        to the node's own ``format`` method otherwise.

        :param prim: The primitive or terminal node.
        :param args: The already formatted argument strings of the node.
        :returns: The formatted string for the node.
        """
        formatter = self.sympy_conversions.get(prim.name, prim.format)
        return formatter(*args)

    def _stringify_for_sympy(self, f):
        """
        Return the tree ``f`` (stored in prefix order) as a string sympy can parse.

        :param f: The tree to stringify.
        :returns: The expression string, or ``""`` if ``f`` is empty.
        """
        string = ""
        stack = []
        for node in f:
            stack.append((node, []))
            # A node is complete once it has collected as many formatted
            # arguments as its arity; fold it into its parent and repeat.
            while len(stack[-1][1]) == stack[-1][0].arity:
                prim, args = stack.pop()
                string = self._convert_inverse_prim(prim, args)
                if len(stack) == 0:
                    break  # If stack is empty, all nodes should have been seen
                stack[-1][1].append(string)
        return string

    def simplify(self, individual):
        """
        Simplify an individual with sympy.

        :param individual: The tree to simplify.
        :returns: The result of ``sympy.simplify`` on the stringified tree.
        """
        return sympy.simplify(self._stringify_for_sympy(individual))

    def repair(self, individual):
        """
        Refit the coefficients of an individual's additive terms with OLS.

        The individual is split into top-level terms, regressed against the
        outcome, and rebuilt as ``intercept + coef_1 * term_1 + ...``.

        :param individual: The tree to repair.
        :returns: The repaired tree, or ``individual`` unchanged if the model
                  could not be built or fitted.
        """
        eq = f"{self.outcome} ~ {' + '.join(str(x) for x in self.split(individual))}"
        try:
            res = smf.ols(eq, self.df).fit()

            # Rebuild the expression as nested add/mul calls around the fitted values.
            eqn = f"{res.params['Intercept']}"
            for term, coefficient in res.params.items():
                if term != "Intercept":
                    eqn = f"add({eqn}, mul({coefficient}, {term}))"
            return type(individual)(gp.PrimitiveTree.from_string(eqn, self.pset))
        except (
            OverflowError,
            ValueError,
            ZeroDivisionError,
            statsmodels.tools.sm_exceptions.MissingDataError,
            patsy.PatsyError,
        ):
            return individual

    def evalSymbReg(self, individual):
        """
        Score an individual by its normalised mean squared error on ``self.df``.

        The mean squared error is divided by the absolute mean of the outcome, so
        the score is non-negative whatever the sign of the outcome.

        :param individual: The tree to evaluate.
        :returns: A one-element tuple with the fitness; ``(inf,)`` if evaluating
                  the expression raised a numerical error.
        """
        # Turn numpy floating point warnings into exceptions for the duration
        # of the evaluation so that invalid expressions are penalised.
        old_settings = np.seterr(all="raise")
        try:
            func = gp.compile(individual, self.pset)
            # Reuse the frame's index: a fresh RangeIndex would be aligned by label
            # in the subtraction below and produce NaNs for non-default indexes.
            y_estimates = pd.Series(
                [func(**x) for _, x in self.df[self.features].iterrows()],
                index=self.df.index,
            )

            sqerrors = (self.df[self.outcome] - y_estimates) ** 2
            mean_squared = sqerrors.sum() / len(self.df)
            # abs() keeps the score positive when the outcome mean is negative;
            # otherwise minimising the fitness would favour larger errors.
            nmse = mean_squared / abs(self.df[self.outcome].sum() / len(self.df))

            return (nmse,)

        # Fitness value of infinite if error - not return 1
        except (
            OverflowError,
            ValueError,
            ZeroDivisionError,
            statsmodels.tools.sm_exceptions.MissingDataError,
            patsy.PatsyError,
            RuntimeWarning,
            FloatingPointError,
        ):
            return (float("inf"),)
        finally:
            np.seterr(**old_settings)  # Restore original settings

    def make_offspring(self, population, lambda_):
        """
        Produce ``lambda_`` children by tournament selection, crossover and mutation.

        :param population: The individuals to select parents from.
        :param lambda_: The number of children to create.
        :returns: A list of ``lambda_`` new individuals without fitness values.
        """
        offspring = []
        for _ in range(lambda_):
            parent1, parent2 = tools.selTournament(population, 2, 2)
            # Only the first child of the crossover is kept.
            child, _sibling = self.toolbox.mate(self.toolbox.clone(parent1), self.toolbox.clone(parent2))
            del child.fitness.values
            (child,) = self.toolbox.mutate(child)
            offspring.append(child)
        return offspring

    def eaMuPlusLambda(self, ngen, mu=20, lambda_=10, stats=None, verbose=False, seeds=None):
        """
        Run a (mu + lambda) evolutionary search and return the best individual.

        :param ngen: The number of generations.
        :param mu: The number of individuals kept in each generation.
        :param lambda_: The number of children produced in each generation.
        :param stats: Optional DEAP statistics object compiled into the logbook.
        :param verbose: Whether to print the logbook after every generation.
        :param seeds: Optional expression strings (parsable against the primitive
                      set) added to the initial population as they are.
        :returns: The individual with the lowest fitness, ties broken by height.
        """
        population = [self.toolbox.repair(ind) for ind in self.toolbox.population(n=mu)]
        if seeds is not None:
            # Seeds are evaluated together with the rest of the population below.
            population.extend(creator.Individual(gp.PrimitiveTree.from_string(seed, self.pset)) for seed in seeds)

        logbook = tools.Logbook()
        logbook.header = ["gen", "nevals"] + (stats.fields if stats else [])

        # Evaluate the initial population
        for ind in population:
            ind.fitness.values = self.toolbox.evaluate(ind)
        population.sort(key=lambda x: (x.fitness.values, x.height))

        record = stats.compile(population) if stats is not None else {}
        logbook.record(gen=0, nevals=len(population), **record)
        if verbose:
            print(logbook.stream)

        # Begin the generational process
        for gen in range(1, ngen + 1):
            # Vary the population
            offspring = self.make_offspring(population, lambda_)
            offspring = [self.toolbox.repair(ind) for ind in offspring]

            # Evaluate the new individuals
            for ind in offspring:
                ind.fitness.values = self.toolbox.evaluate(ind)

            # Select the next generation population
            population[:] = self.toolbox.select(population + offspring, mu)

            # Update the statistics with the new population
            record = stats.compile(population) if stats is not None else {}
            logbook.record(gen=gen, nevals=len(offspring), **record)
            if verbose:
                print(logbook.stream)
            population.sort(key=lambda x: (x.fitness.values, x.height))

        return population[0]

    def mutate(self, individual, expr):
        """
        Apply one of three mutations, chosen uniformly at random, to a copy of
        ``individual``: node replacement, insertion, or shrinking.

        :param individual: The tree to mutate; it is cloned, not modified.
        :param expr: Unused; accepted because the toolbox registers this method
                     with an ``expr`` keyword argument.
        :returns: A tuple of one tree.
        """
        choice = random.randint(1, 3)
        mutant = self.toolbox.clone(individual)
        if choice == 1:
            return gp.mutNodeReplacement(mutant, self.pset)
        if choice == 2:
            return mutInsert(mutant, self.pset)
        return gp.mutShrink(mutant)
301+
302+
303+
if __name__ == "__main__":
    # Small demonstration: recover Y = 1 / (X + 1) from ten data points.
    demo = pd.DataFrame()
    demo["X"] = np.arange(10)
    demo["Y"] = 1 / (demo["X"] + 1)

    search = GP(demo.astype(float), ["X"], "Y", seed=1)
    fittest = search.eaMuPlusLambda(ngen=100)
    print(fittest, fittest.fitness.values[0])
    print(search.simplify(fittest))

causal_testing/testing/estimators.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from causal_testing.specification.variable import Variable
2020
from causal_testing.specification.capabilities import TreatmentSequence, Capability
21+
from causal_testing.gp.gp import GP
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -333,6 +334,19 @@ def __init__(
333334
for term in self.effect_modifiers:
334335
self.adjustment_set.add(term)
335336

337+
def gp_formula(self, ngen=100, mu=20, lambda_=10, extra_operators=None, sympy_conversions=None, seeds=None, seed=0):
    """
    Search for a regression formula with genetic programming and store it in
    ``self.formula``.

    The features offered to the search are the adjustment set plus the
    treatment, in sorted order. The best expression found is simplified and
    wrapped as ``<outcome> ~ I(<expression>) - 1``.

    :param ngen: The number of generations to run.
    :param mu: The number of individuals kept in each generation.
    :param lambda_: The number of children produced in each generation.
    :param extra_operators: Passed through to ``GP``.
    :param sympy_conversions: Passed through to ``GP``.
    :param seeds: Passed through to ``GP.eaMuPlusLambda``.
    :param seed: Passed through to ``GP``.
    """
    feature_names = sorted(self.adjustment_set.union([self.treatment]))
    searcher = GP(
        df=self.df,
        features=feature_names,
        outcome=self.outcome,
        extra_operators=extra_operators,
        sympy_conversions=sympy_conversions,
        seed=seed,
    )
    best = searcher.eaMuPlusLambda(ngen=ngen, mu=mu, lambda_=lambda_, seeds=seeds)
    simplified = searcher.simplify(best)
    self.formula = f"{self.outcome} ~ I({simplified}) - 1"
349+
336350
def add_modelling_assumptions(self):
337351
"""
338352
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
@@ -421,7 +435,13 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
421435
if str(x.dtypes[col]) == "object":
422436
x = pd.get_dummies(x, columns=[col], drop_first=True)
423437
x = x[model.params.index]
438+
439+
# This is a hack for "I(...)" equations
440+
x[self.treatment] = [self.treatment_value, self.control_value]
441+
424442
y = model.get_prediction(x).summary_frame()
443+
print("=== Y ===")
444+
print(y)
425445

426446
return y.iloc[1], y.iloc[0]
427447

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ dependencies = [
2727
"statsmodels~=0.14",
2828
"tabulate~=0.9",
2929
"pydot~=2.0",
30-
"pygad~=3.3"
30+
"pygad~=3.3",
31+
"deap~=1.4.1",
32+
"sympy~=1.13.1",
3133
]
3234
dynamic = ["version"]
3335

tests/testing_tests/test_estimators.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,21 @@ def test_program_11_2_with_robustness_validation(self):
402402
cv = CausalValidator()
403403
self.assertEqual(round(cv.estimate_robustness(model)["treatments"], 4), 0.7353)
404404

405+
def test_gp(self):
    """Check the formula found by GP for Y = 1 / (X + 1) and the ATE estimated from it."""
    data = pd.DataFrame()
    data["X"] = np.arange(10)
    data["Y"] = 1 / (data["X"] + 1)
    estimator = LinearRegressionEstimator("X", 0, 1, set(), "Y", data.astype(float))
    estimator.gp_formula(seed=1)
    expected_formula = "Y ~ I((2.606801258739728e-17*X + 0.626132756132756)/(0.6261327561327561*X + 0.626132756132756)) - 1"
    self.assertEqual(estimator.formula, expected_formula)
    ate, (ci_low, ci_high) = estimator.estimate_ate_calculated()
    # The estimate and both confidence bounds should all round to 0.50.
    for value in (ate, ci_low, ci_high):
        self.assertEqual(round(value[0], 2), 0.50)
419+
405420

406421
class TestCubicSplineRegressionEstimator(TestLinearRegressionEstimator):
407422
@classmethod

0 commit comments

Comments
 (0)