Skip to content

Commit fe6e629

Browse files
committed
Fix stat_quantile to handle formula that use environment variables
1 parent 857118e commit fe6e629

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

plotnine/stats/stat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class stat(ABC, metaclass=Register):
6464

6565
# Plot namespace, it gets its value when the plot is being
6666
# built.
67-
environment: Environment | None = None
67+
environment: Environment
6868

6969
def __init__(
7070
self,

plotnine/stats/stat_quantile.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ def setup_params(self, data):
6363
if params["formula"] is None:
6464
params["formula"] = "y ~ x"
6565
warn("Formula not specified, using '{}'", PlotnineWarning)
66+
else:
67+
from patsy.eval import EvalEnvironment
68+
69+
params["eval_env"] = EvalEnvironment(
70+
namespaces=self.environment.namespaces
71+
)
72+
6673
try:
6774
iter(params["quantiles"])
6875
except TypeError:
@@ -81,7 +88,11 @@ def quant_pred(q, data, params):
8188
"""
8289
import statsmodels.formula.api as smf
8390

84-
mod = smf.quantreg(params["formula"], data)
91+
mod = smf.quantreg(
92+
params["formula"],
93+
data,
94+
eval_env=params.get("eval_env"),
95+
)
8596
reg_res = mod.fit(q=q, **params["method_args"])
8697
out = pd.DataFrame(
8798
{

tests/test_geom_quantile.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,18 @@ def test_lines():
1818
+ geom_quantile(quantiles=[0.001, 0.5, 0.999], formula="y~x", size=2)
1919
)
2020

21+
# np.absolute tests the ability to pickup variables in the
22+
# caller environment
23+
p2 = (
24+
ggplot(data, aes(x="x", y="y"))
25+
+ geom_point(alpha=0.5)
26+
+ geom_quantile(
27+
quantiles=[0.001, 0.5, 0.999], formula="y~np.absolute(x)", size=2
28+
)
29+
)
30+
2131
# Two (.001, .999) quantile lines should bound the points
2232
# from below and from above, and the .5 line should go
2333
# through middle (approximately).
2434
assert p == "lines"
35+
assert p2 == "lines"

0 commit comments

Comments
 (0)