Skip to content

Commit 08dea8f

Browse files
committed
Fix gradient bug in models with max operation
1 parent f90e44f commit 08dea8f

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

pymc/logprob/rewriting.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,6 @@ def incsubtensor_rv_replace(fgraph, node):
365365
"local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic"
366366
)
367367
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
368-
# Split max_and_argmax
369-
logprob_rewrites_db.register("local_max_and_argmax", out2in(local_max_and_argmax), "basic")
370368

371369
# These rewrites convert un-measurable variables into their measurable forms,
372370
# but they need to be reapplied, because some of the measurable forms require
@@ -376,6 +374,12 @@ def incsubtensor_rv_replace(fgraph, node):
376374

377375
logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic")
378376

377+
# Split max_and_argmax
378+
# We only register this in the measurable IR db because max does not have a grad implemented
379+
# And running this on any MaxAndArgmax would lead to issues: https://github.com/pymc-devs/pymc/issues/7251
380+
# This special registering can be removed after https://github.com/pymc-devs/pytensor/issues/334 is fixed
381+
measurable_ir_rewrites_db.register("local_max_and_argmax", local_max_and_argmax, "basic")
382+
379383
# These rewrites push random/measurable variables "down", making them closer to
380384
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
381385
# "up" through the random/measurable variables and into their inputs.

tests/logprob/test_order.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import pymc as pm
4646

4747
from pymc import logp
48+
from pymc.logprob import conditional_logp
4849
from pymc.testing import assert_no_rvs
4950

5051

@@ -293,3 +294,18 @@ def test_min_max_bernoulli():
293294
min_logp_fn = pytensor.function([value], pm.logp(pt.min(x), value))
294295
np.testing.assert_allclose(min_logp_fn(1), np.log(p**n))
295296
np.testing.assert_allclose(min_logp_fn(0), np.log(1 - p**n))
297+
298+
299+
def test_non_measurable_max_grad():
300+
# Regression test for https://github.com/pymc-devs/pytensor/issues/711
301+
x = pt.random.normal(0, 1, size=(3,))
302+
max_x = x.max()
303+
y = pt.random.normal(max_x, 1)
304+
305+
x_vv = x.type()
306+
y_vv = y.type()
307+
logp_terms = conditional_logp({x: x_vv, y: y_vv}).values()
308+
joint_logp = pt.sum([term.sum() for term in logp_terms])
309+
310+
# Test that calling gradient does not raise a NotImplementedError
311+
assert pt.grad(joint_logp, x_vv)

0 commit comments

Comments (0)