Skip to content

Commit 5d915ed

Browse files
committed
pylint
1 parent 62e6b3d commit 5d915ed

File tree

3 files changed

+21
-12
lines changed

3 files changed

+21
-12
lines changed

causal_testing/estimation/gp.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ def mut_insert(expression: gp.PrimitiveTree, pset: gp.PrimitiveSet):
7575
return (expression,)
7676

7777

78+
def create_power_function(power):
79+
def power_func(x):
80+
return power(x, power)
81+
82+
def sympy_conversion(x1):
83+
return f"Pow({x1},{i})"
84+
85+
return power_func, sympy_conversion
86+
87+
7888
class GP:
7989
"""
8090
Object to perform genetic programming.
@@ -90,7 +100,7 @@ def __init__(
90100
sympy_conversions: dict = None,
91101
seed=0,
92102
):
93-
# pylint: disable=too-many-arguments
103+
# pylint: disable=too-many-arguments,too-many-instance-attributes
94104
random.seed(seed)
95105
self.df = df
96106
self.features = features
@@ -115,15 +125,15 @@ def __init__(
115125
} | sympy_conversions
116126

117127
for i in range(self.max_order):
118-
print("Adding in order", i)
119128
name = f"power_{i}"
120-
self.pset.addPrimitive(lambda x: power(x, i), 1, name=name)
129+
func, conversion = create_power_function(i)
130+
self.pset.addPrimitive(func, 1, name=name)
121131
if name in self.sympy_conversions:
122132
raise ValueError(
123133
f"You have provided a function called {name}, which is reserved for raising to power"
124134
f"{i}. Please choose a different name for your function."
125135
)
126-
self.sympy_conversions[name] = lambda x1: f"Pow({x1},{i})"
136+
self.sympy_conversions[name] = conversion
127137

128138
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
129139
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)

causal_testing/estimation/linear_regression_estimator.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def gp_formula(
6565
ngen: int = 100,
6666
pop_size: int = 20,
6767
num_offspring: int = 10,
68+
max_order: int = 0,
6869
extra_operators: list = None,
6970
sympy_conversions: dict = None,
70-
max_order: int = 0,
7171
seeds: list = None,
7272
seed: int = 0,
7373
):
@@ -76,17 +76,18 @@ def gp_formula(
7676
Use Genetic Programming (GP) to infer the regression equation from the data.
7777
7878
:param ngen: The maximum number of GP generations to run for.
79-
:param mu: The GP population size.
79+
:param pop_size: The GP population size.
8080
:param num_offspring: The number of offspring per generation.
81-
:param extra_operators: Additional operators for the GP (defaults are +, *, and 1/x). Operations should be of
81+
:param max_order: The maximum polynomial order to use, e.g. `max_order=2` will give polynomials of the form `ax^2 + bx + c`.
82+
:param extra_operators: Additional operators for the GP (defaults are +, *, log(x), and 1/x). Operations should be of
8283
the form (fun, numArgs), e.g. (add, 2).
8384
:param sympy_conversions: Dictionary of conversions of extra_operators for sympy,
8485
e.g. `"mul": lambda *args_: "Mul({},{})".format(*args_)`.
8586
:param seeds: Seed individuals for the population (e.g. if you think that the relationship between X and Y is
8687
probably logarithmic, you can put that in).
8788
:param seed: Random seed for the GP.
8889
"""
89-
self.gp = GP(
90+
gp = GP(
9091
df=self.df,
9192
features=sorted(list(self.adjustment_set.union([self.treatment]))),
9293
outcome=self.outcome,
@@ -95,8 +96,8 @@ def gp_formula(
9596
seed=seed,
9697
max_order=max_order,
9798
)
98-
formula = self.gp.run_gp(ngen=ngen, pop_size=num_offspring, num_offspring=num_offspring, seeds=seeds)
99-
formula = self.gp.simplify(formula)
99+
formula = gp.run_gp(ngen=ngen, pop_size=pop_size, num_offspring=num_offspring, seeds=seeds)
100+
formula = gp.simplify(formula)
100101
self.formula = f"{self.outcome} ~ I({formula}) - 1"
101102

102103
def add_modelling_assumptions(self):

tests/estimation_tests/test_linear_regression_estimator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,6 @@ def test_gp(self):
277277
df["Y"] = 1 / (df["X"] + 1)
278278
linear_regression_estimator = LinearRegressionEstimator("X", 0, 1, set(), "Y", df.astype(float))
279279
linear_regression_estimator.gp_formula(seeds=["reciprocal(add(X, 1))"])
280-
print("MAPPING")
281-
print(linear_regression_estimator.gp.pset.mapping)
282280
self.assertEqual(linear_regression_estimator.formula, "Y ~ I(1/(X + 1)) - 1")
283281
ate, (ci_low, ci_high) = linear_regression_estimator.estimate_ate_calculated()
284282
self.assertEqual(round(ate[0], 2), 0.50)

0 commit comments

Comments
 (0)