42
42
import pytensor .tensor as pt
43
43
44
44
from pytensor .graph .basic import Variable
45
- from pytensor .graph .fg import FunctionGraph
46
45
from pytensor .graph .op import compute_test_value
47
46
from pytensor .graph .rewriting .basic import node_rewriter
48
47
from pytensor .graph .rewriting .db import RewriteDatabaseQuery
49
48
from pytensor .scan .op import Scan
50
49
from pytensor .scan .rewriting import scan_eqopt1 , scan_eqopt2
51
50
from pytensor .scan .utils import ScanArgs
52
51
from pytensor .tensor .random .type import RandomType
53
- from pytensor .tensor .rewriting .shape import ShapeFeature
54
52
from pytensor .tensor .subtensor import Subtensor , indices_from_subtensor
55
53
from pytensor .tensor .var import TensorVariable
56
54
from pytensor .updates import OrderedUpdates
63
61
)
64
62
from pymc .logprob .joint_logprob import factorized_joint_logprob
65
63
from pymc .logprob .rewriting import (
66
- PreserveRVMappings ,
64
+ construct_ir_fgraph ,
67
65
inc_subtensor_ops ,
68
66
logprob_rewrites_db ,
69
67
measurable_ir_rewrites_db ,
70
68
)
69
+ from pymc .pytensorf import replace_rvs_by_values
71
70
72
71
73
72
class MeasurableScan (Scan ):
@@ -249,9 +248,27 @@ def remove(x, i):
249
248
new_inner_out_nit_sot = tuple (output_scan_args .inner_out_nit_sot ) + tuple (
250
249
inner_out_fn (remapped_io_to_ii )
251
250
)
252
-
253
251
output_scan_args .inner_out_nit_sot = list (new_inner_out_nit_sot )
254
252
253
+ # Finally, we need to replace any lingering references to the new
254
+ # internal variables that could be in the recurrent states needed
255
+ # to compute the new nit_sots
256
+ traced_outs = (
257
+ output_scan_args .inner_out_mit_sot
258
+ + output_scan_args .inner_out_sit_sot
259
+ + output_scan_args .inner_out_nit_sot
260
+ )
261
+ traced_outs = replace_rvs_by_values (traced_outs , rvs_to_values = remapped_io_to_ii )
262
+ # Update output mappings
263
+ n_mit_sot = len (output_scan_args .inner_out_mit_sot )
264
+ output_scan_args .inner_out_mit_sot = traced_outs [:n_mit_sot ]
265
+ offset = n_mit_sot
266
+ n_sit_sot = len (output_scan_args .inner_out_sit_sot )
267
+ output_scan_args .inner_out_sit_sot = traced_outs [offset : offset + n_sit_sot ]
268
+ offset += n_sit_sot
269
+ n_nit_sot = len (output_scan_args .inner_out_nit_sot )
270
+ output_scan_args .inner_out_nit_sot = traced_outs [offset : offset + n_nit_sot ]
271
+
255
272
return output_scan_args
256
273
257
274
@@ -331,7 +348,10 @@ def create_inner_out_logp(value_map: Dict[TensorVariable, TensorVariable]) -> Te
331
348
for key , value in updates .items ():
332
349
key .default_update = value
333
350
334
- return logp_scan_out
351
+ # Return only the logp outputs, not any potentially carried states
352
+ logp_outputs = logp_scan_out [- len (values ) :]
353
+
354
+ return logp_outputs
335
355
336
356
337
357
@node_rewriter ([Scan ])
@@ -504,19 +524,9 @@ def add_opts_to_inner_graphs(fgraph, node):
504
524
if getattr (node .op .mode , "had_logprob_rewrites" , False ):
505
525
return None
506
526
507
- inner_fgraph = FunctionGraph (
508
- node .op .inner_inputs ,
509
- node .op .inner_outputs ,
510
- clone = True ,
511
- copy_inputs = False ,
512
- copy_orphans = False ,
513
- features = [
514
- ShapeFeature (),
515
- PreserveRVMappings ({}),
516
- ],
517
- )
518
-
519
- logprob_rewrites_db .query (RewriteDatabaseQuery (include = ["basic" ])).rewrite (inner_fgraph )
527
+ inner_rv_values = {out : out .type () for out in node .op .inner_outputs }
528
+ ir_rewriter = logprob_rewrites_db .query (RewriteDatabaseQuery (include = ["basic" ]))
529
+ inner_fgraph , rv_values , _ = construct_ir_fgraph (inner_rv_values , ir_rewriter = ir_rewriter )
520
530
521
531
new_outputs = list (inner_fgraph .outputs )
522
532
@@ -531,11 +541,23 @@ def add_opts_to_inner_graphs(fgraph, node):
531
541
532
542
533
543
@_get_measurable_outputs .register (MeasurableScan )
534
- def _get_measurable_outputs_MeasurableScan (op , node ):
535
- # TODO: This should probably use `get_random_outer_outputs`
536
- # scan_args = ScanArgs.from_node(node)
537
- # rv_outer_outs = get_random_outer_outputs(scan_args)
538
- return [o for o in node .outputs if not isinstance (o .type , RandomType )]
544
+ def _get_measurable_outputs_MeasurableScan (op : Scan , node ):
545
+ """Collect measurable outputs for Measurable Scans"""
546
+ inner_out_from_outer_out_map = op .get_oinp_iinp_iout_oout_mappings ()["inner_out_from_outer_out" ]
547
+ inner_outs = op .inner_outputs
548
+
549
+ # Measurable scan outputs are those whose inner scan output counterparts are also measurable
550
+ measurable_outputs = []
551
+ for out_idx , out in enumerate (node .outputs ):
552
+ [inner_out_idx ] = inner_out_from_outer_out_map [out_idx ]
553
+ inner_out = inner_outs [inner_out_idx ]
554
+ inner_out_node = inner_out .owner
555
+ if isinstance (
556
+ inner_out_node .op , MeasurableVariable
557
+ ) and inner_out in get_measurable_outputs (inner_out_node .op , inner_out_node ):
558
+ measurable_outputs .append (out )
559
+
560
+ return measurable_outputs
539
561
540
562
541
563
measurable_ir_rewrites_db .register (
0 commit comments