Skip to content

Commit 62e6b3d

Browse files
committed
pytest
1 parent 3e25256 commit 62e6b3d

File tree

3 files changed

+46
-16
lines changed

3 files changed

+46
-16
lines changed

causal_testing/estimation/gp.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,33 +85,46 @@ def __init__(
8585
df: pd.DataFrame,
8686
features: list,
8787
outcome: str,
88+
max_order: int = 0,
8889
extra_operators: list = None,
8990
sympy_conversions: dict = None,
9091
seed=0,
9192
):
9293
# pylint: disable=too-many-arguments
9394
random.seed(seed)
94-
np.random.seed(seed)
9595
self.df = df
9696
self.features = features
9797
self.outcome = outcome
98+
self.max_order = max_order
9899
self.seed = seed
99100
self.pset = gp.PrimitiveSet("MAIN", len(self.features))
100101
self.pset.renameArguments(**{f"ARG{i}": f for i, f in enumerate(self.features)})
101102

102-
standard_operators = [(add, 2), (mul, 2), (reciprocal, 1)]
103+
standard_operators = [(add, 2), (mul, 2)]
103104
if extra_operators is None:
104105
extra_operators = [(log, 1), (reciprocal, 1)]
105-
for operator, num_args in standard_operators + extra_operators:
106-
self.pset.addPrimitive(operator, num_args)
107106
if sympy_conversions is None:
108107
sympy_conversions = {}
108+
for operator, num_args in standard_operators + extra_operators:
109+
self.pset.addPrimitive(operator, num_args)
110+
109111
self.sympy_conversions = {
110112
"mul": lambda x1, x2: f"Mul({x1},{x2})",
111113
"add": lambda x1, x2: f"Add({x1},{x2})",
112114
"reciprocal": lambda x1: f"Pow({x1},-1)",
113115
} | sympy_conversions
114116

117+
# Register one unary "raise to the i-th power" primitive for each order in
# range(self.max_order), together with its sympy string conversion.
# NOTE(review): the orders run 0..max_order-1, so max_order itself is never
# registered and orders 0 and 1 are trivial - confirm this range is intended.
for i in range(self.max_order):
    name = f"power_{i}"
    # Reject a clashing user-supplied conversion *before* touching the
    # primitive set, so a failure does not leave a half-registered operator.
    if name in self.sympy_conversions:
        raise ValueError(
            f"You have provided a function called {name}, which is reserved for raising to power "
            f"{i}. Please choose a different name for your function."
        )
    # Bind the current order as a default argument: a plain closure over `i`
    # is late-binding, so every primitive would otherwise use the final value
    # of `i` once the loop has finished.
    self.pset.addPrimitive(lambda x, order=i: power(x, order), 1, name=name)
    self.sympy_conversions[name] = lambda x1, order=i: f"Pow({x1},{order})"
127+
115128
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
116129
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)
117130

causal_testing/estimation/linear_regression_estimator.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,21 @@ def __init__(
6363
def gp_formula(
    self,
    ngen: int = 100,
    pop_size: int = 20,
    num_offspring: int = 10,
    extra_operators: list = None,
    sympy_conversions: dict = None,
    max_order: int = 0,
    seeds: list = None,
    seed: int = 0,
):
    # pylint: disable=too-many-arguments
    """
    Use Genetic Programming (GP) to infer the regression equation from the data.

    The GP instance is stored on ``self.gp`` and the simplified result is written to
    ``self.formula`` in the form ``"<outcome> ~ I(<formula>) - 1"``. Nothing is returned.

    :param ngen: The maximum number of GP generations to run for.
    :param pop_size: The GP population size.
    :param num_offspring: The number of offspring per generation.
    :param extra_operators: Additional operators for the GP. Operations should be of
                            the form (fun, numArgs), e.g. (add, 2). Passed through to GP unchanged.
    :param sympy_conversions: Dictionary of conversions of extra_operators for sympy, keyed by
                              operator name. Passed through to GP unchanged.
    :param max_order: Passed through to GP, which registers one power operator for each order in
                      ``range(max_order)``.
    :param seeds: Seed individuals passed to ``run_gp``, e.g. ``["reciprocal(add(X, 1))"]``
                  (e.g. if you think that the relationship between X and Y is probably logarithmic,
                  you can put that in).
    :param seed: Random seed for the GP.
    """
    # The GP features are the adjustment set plus the treatment, sorted so that the
    # argument order is deterministic.
    features = sorted(self.adjustment_set.union([self.treatment]))
    self.gp = GP(
        df=self.df,
        features=features,
        outcome=self.outcome,
        extra_operators=extra_operators,
        sympy_conversions=sympy_conversions,
        seed=seed,
        max_order=max_order,
    )
    # pop_size must be forwarded as-is; passing num_offspring here would silently
    # ignore the caller's population size.
    formula = self.gp.run_gp(ngen=ngen, pop_size=pop_size, num_offspring=num_offspring, seeds=seeds)
    formula = self.gp.simplify(formula)
    self.formula = f"{self.outcome} ~ I({formula}) - 1"
99101

100102
def add_modelling_assumptions(self):

tests/estimation_tests/test_linear_regression_estimator.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from causal_testing.specification.capabilities import TreatmentSequence
88

99
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
10+
from causal_testing.estimation.gp import reciprocal
1011

1112

1213
def load_nhefs_df():
@@ -275,16 +276,30 @@ def test_gp(self):
275276
df["X"] = np.arange(10)
276277
df["Y"] = 1 / (df["X"] + 1)
277278
linear_regression_estimator = LinearRegressionEstimator("X", 0, 1, set(), "Y", df.astype(float))
278-
linear_regression_estimator.gp_formula(seed=1)
279-
self.assertEqual(
280-
linear_regression_estimator.formula,
281-
"Y ~ I((2.606801258739728e-17*X + 0.626132756132756)/(0.6261327561327561*X + 0.626132756132756)) - 1",
282-
)
279+
linear_regression_estimator.gp_formula(seeds=["reciprocal(add(X, 1))"])
280+
print("MAPPING")
281+
print(linear_regression_estimator.gp.pset.mapping)
282+
self.assertEqual(linear_regression_estimator.formula, "Y ~ I(1/(X + 1)) - 1")
283283
ate, (ci_low, ci_high) = linear_regression_estimator.estimate_ate_calculated()
284284
self.assertEqual(round(ate[0], 2), 0.50)
285285
self.assertEqual(round(ci_low[0], 2), 0.50)
286286
self.assertEqual(round(ci_high[0], 2), 0.50)
287287

288+
def test_gp_power(self):
    """Fit Y = 2 * X**2 with the GP and check the inferred formula and the resulting ATE."""
    # Ten integer points on the quadratic; the estimator receives them as floats.
    df = pd.DataFrame({"X": np.arange(10)})
    df["Y"] = 2 * (df["X"] ** 2)
    estimator = LinearRegressionEstimator("X", 0, 1, set(), "Y", df.astype(float))

    estimator.gp_formula(seed=1, max_order=0)
    expected_formula = "Y ~ I(2.0*X**2 + 3.8205100524608823e-31) - 1"
    self.assertEqual(estimator.formula, expected_formula)

    # The point estimate and both confidence bounds should all round to the same value.
    ate, (ci_low, ci_high) = estimator.estimate_ate_calculated()
    for estimate in (ate, ci_low, ci_high):
        self.assertEqual(round(estimate[0], 2), -2.00)
302+
288303

289304
class TestLinearRegressionInteraction(unittest.TestCase):
290305
"""Test linear regression for estimating effects involving interaction."""

0 commit comments

Comments
 (0)