File tree Expand file tree Collapse file tree 2 files changed +22
-2
lines changed Expand file tree Collapse file tree 2 files changed +22
-2
lines changed Original file line number Diff line number Diff line change @@ -365,8 +365,6 @@ def incsubtensor_rv_replace(fgraph, node):
365
365
"local_exp_over_1_plus_exp" , out2in (local_exp_over_1_plus_exp ), "basic"
366
366
)
367
367
logprob_rewrites_db .register ("pre-canonicalize" , optdb .query ("+canonicalize" ), "basic" )
368
- # Split max_and_argmax
369
- logprob_rewrites_db .register ("local_max_and_argmax" , out2in (local_max_and_argmax ), "basic" )
370
368
371
369
# These rewrites convert un-measurable variables into their measurable forms,
372
370
# but they need to be reapplied, because some of the measurable forms require
@@ -376,6 +374,12 @@ def incsubtensor_rv_replace(fgraph, node):
376
374
377
375
logprob_rewrites_db .register ("measurable_ir_rewrites" , measurable_ir_rewrites_db , "basic" )
378
376
377
+ # Split max_and_argmax
378
+ # We only register this in the measurable IR db because max does not have a grad implemented
379
+ # And running this on any MaxAndArgmax would lead to issues: https://github.com/pymc-devs/pymc/issues/7251
380
+ # This special registering can be removed after https://github.com/pymc-devs/pytensor/issues/334 is fixed
381
+ measurable_ir_rewrites_db .register ("local_max_and_argmax" , local_max_and_argmax , "basic" )
382
+
379
383
# These rewrites push random/measurable variables "down", making them closer to
380
384
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
381
385
# "up" through the random/measurable variables and into their inputs.
Original file line number Diff line number Diff line change 45
45
import pymc as pm
46
46
47
47
from pymc import logp
48
+ from pymc .logprob import conditional_logp
48
49
from pymc .testing import assert_no_rvs
49
50
50
51
@@ -293,3 +294,18 @@ def test_min_max_bernoulli():
293
294
min_logp_fn = pytensor .function ([value ], pm .logp (pt .min (x ), value ))
294
295
np .testing .assert_allclose (min_logp_fn (1 ), np .log (p ** n ))
295
296
np .testing .assert_allclose (min_logp_fn (0 ), np .log (1 - p ** n ))
297
+
298
+
299
+ def test_non_measurable_max_grad ():
300
+ # Regression test for https://github.com/pymc-devs/pytensor/issues/711
301
+ x = pt .random .normal (0 , 1 , size = (3 ,))
302
+ max_x = x .max ()
303
+ y = pt .random .normal (max_x , 1 )
304
+
305
+ x_vv = x .type ()
306
+ y_vv = y .type ()
307
+ logp_terms = conditional_logp ({x : x_vv , y : y_vv }).values ()
308
+ joint_logp = pt .sum ([term .sum () for term in logp_terms ])
309
+
310
+ # Test that calling gradient does not raise a NotImplementedError
311
+ assert pt .grad (joint_logp , x_vv )
You can’t perform that action at this time.
0 commit comments