39
39
from copy import copy
40
40
from typing import Callable , Dict , List , Optional , Sequence , Tuple , Union
41
41
42
+ import numpy as np
42
43
import pytensor .tensor as at
43
44
44
45
from pytensor .gradient import DisconnectedType , jacobian
48
49
from pytensor .graph .op import Op
49
50
from pytensor .graph .replace import clone_replace
50
51
from pytensor .graph .rewriting .basic import GraphRewriter , in2out , node_rewriter
51
- from pytensor .scalar import Add , Exp , Log , Mul , Reciprocal
52
+ from pytensor .scalar import Add , Exp , Log , Mul , Pow , Sqr , Sqrt
52
53
from pytensor .scan .op import Scan
53
54
from pytensor .tensor .exceptions import NotScalarConstantError
54
- from pytensor .tensor .math import add , exp , log , mul , neg , reciprocal , sub , true_div
55
+ from pytensor .tensor .math import (
56
+ add ,
57
+ exp ,
58
+ log ,
59
+ mul ,
60
+ neg ,
61
+ pow ,
62
+ reciprocal ,
63
+ sqr ,
64
+ sqrt ,
65
+ sub ,
66
+ true_div ,
67
+ )
55
68
from pytensor .tensor .rewriting .basic import (
56
69
register_specialize ,
57
70
register_stabilize ,
@@ -110,8 +123,11 @@ def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
110
123
"""Apply the transformation."""
111
124
112
125
@abc .abstractmethod
113
- def backward (self , value : TensorVariable , * inputs : Variable ) -> TensorVariable :
114
- """Invert the transformation."""
126
+ def backward (
127
+ self , value : TensorVariable , * inputs : Variable
128
+ ) -> Union [TensorVariable , Tuple [TensorVariable , ...]]:
129
+ """Invert the transformation. Multiple values may be returned when the
130
+ transformation is not 1-to-1"""
115
131
116
132
def log_jac_det (self , value : TensorVariable , * inputs ) -> TensorVariable :
117
133
"""Construct the log of the absolute value of the Jacobian determinant."""
@@ -320,7 +336,7 @@ def apply(self, fgraph: FunctionGraph):
320
336
class MeasurableTransform (MeasurableElemwise ):
321
337
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""
322
338
323
- valid_scalar_types = (Exp , Log , Add , Mul , Reciprocal )
339
+ valid_scalar_types = (Exp , Log , Add , Mul , Pow )
324
340
325
341
# Cannot use `transform` as name because it would clash with the property added by
326
342
# the `TransformValuesRewrite`
@@ -349,16 +365,64 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
349
365
# The value variable must still be back-transformed to be on the natural support of
350
366
# the respective measurable input.
351
367
backward_value = op .transform_elemwise .backward (value , * other_inputs )
352
- input_logprob = logprob (measurable_input , backward_value , ** kwargs )
368
+
369
+ # Some transformations, like squaring may produce multiple backward values
370
+ if isinstance (backward_value , tuple ):
371
+ input_logprob = at .logaddexp (
372
+ * (logprob (measurable_input , backward_val , ** kwargs ) for backward_val in backward_value )
373
+ )
374
+ else :
375
+ input_logprob = logprob (measurable_input , backward_value )
353
376
354
377
jacobian = op .transform_elemwise .log_jac_det (value , * other_inputs )
355
378
356
379
return input_logprob + jacobian
357
380
358
381
382
+ @node_rewriter ([reciprocal ])
383
+ def measurable_reciprocal_to_power (fgraph , node ):
384
+ """Convert reciprocal of `MeasurableVariable`s to power."""
385
+ inp = node .inputs [0 ]
386
+ if not (inp .owner and isinstance (inp .owner .op , MeasurableVariable )):
387
+ return None
388
+
389
+ rv_map_feature : Optional [PreserveRVMappings ] = getattr (fgraph , "preserve_rv_mappings" , None )
390
+ if rv_map_feature is None :
391
+ return None # pragma: no cover
392
+
393
+ # Only apply this rewrite if the variable is unvalued
394
+ if inp in rv_map_feature .rv_values :
395
+ return None # pragma: no cover
396
+
397
+ return [at .pow (inp , - 1.0 )]
398
+
399
+
400
+ @node_rewriter ([sqr , sqrt ])
401
+ def measurable_sqrt_sqr_to_power (fgraph , node ):
402
+ """Convert square root or square of `MeasurableVariable`s to power form."""
403
+
404
+ inp = node .inputs [0 ]
405
+ if not (inp .owner and isinstance (inp .owner .op , MeasurableVariable )):
406
+ return None
407
+
408
+ rv_map_feature : Optional [PreserveRVMappings ] = getattr (fgraph , "preserve_rv_mappings" , None )
409
+ if rv_map_feature is None :
410
+ return None # pragma: no cover
411
+
412
+ # Only apply this rewrite if the variable is unvalued
413
+ if inp in rv_map_feature .rv_values :
414
+ return None # pragma: no cover
415
+
416
+ if isinstance (node .op .scalar_op , Sqr ):
417
+ return [at .pow (inp , 2 )]
418
+
419
+ if isinstance (node .op .scalar_op , Sqrt ):
420
+ return [at .pow (inp , 1 / 2 )]
421
+
422
+
359
423
@node_rewriter ([true_div ])
360
- def measurable_div_to_reciprocal_product (fgraph , node ):
361
- """Convert divisions involving `MeasurableVariable`s to product with reciprocal ."""
424
+ def measurable_div_to_product (fgraph , node ):
425
+ """Convert divisions involving `MeasurableVariable`s to products ."""
362
426
363
427
measurable_vars = [
364
428
var for var in node .inputs if (var .owner and isinstance (var .owner .op , MeasurableVariable ))
@@ -379,9 +443,13 @@ def measurable_div_to_reciprocal_product(fgraph, node):
379
443
# Check if numerator is 1
380
444
try :
381
445
if at .get_scalar_constant_value (numerator ) == 1 :
382
- return [at .reciprocal (denominator )]
446
+ # We convert the denominator directly to a power transform as this
447
+ # must be the measurable input
448
+ return [at .pow (denominator , - 1 )]
383
449
except NotScalarConstantError :
384
450
pass
451
+ # We don't convert the denominator directly to a power transform as
452
+ # it might not be measurable (and therefore not needed)
385
453
return [at .mul (numerator , at .reciprocal (denominator ))]
386
454
387
455
@@ -425,7 +493,7 @@ def measurable_sub_to_neg(fgraph, node):
425
493
return [at .add (minuend , at .neg (subtrahend ))]
426
494
427
495
428
- @node_rewriter ([exp , log , add , mul , reciprocal ])
496
+ @node_rewriter ([exp , log , add , mul , pow ])
429
497
def find_measurable_transforms (fgraph : FunctionGraph , node : Node ) -> Optional [List [Node ]]:
430
498
"""Find measurable transformations from Elemwise operators."""
431
499
@@ -485,8 +553,18 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
485
553
transform = ExpTransform ()
486
554
elif isinstance (scalar_op , Log ):
487
555
transform = LogTransform ()
488
- elif isinstance (scalar_op , Reciprocal ):
489
- transform = ReciprocalTransform ()
556
+ elif isinstance (scalar_op , Pow ):
557
+ # We only allow for the base to be measurable
558
+ if measurable_input_idx != 0 :
559
+ return None
560
+ try :
561
+ (power ,) = other_inputs
562
+ power = at .get_scalar_constant_value (power ).item ()
563
+ # Power needs to be a constant
564
+ except NotScalarConstantError :
565
+ return None
566
+ transform_inputs = (measurable_input , power )
567
+ transform = PowerTransform (power = power )
490
568
elif isinstance (scalar_op , Add ):
491
569
transform_inputs = (measurable_input , at .add (* other_inputs ))
492
570
transform = LocTransform (
@@ -510,12 +588,29 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
510
588
511
589
512
590
measurable_ir_rewrites_db .register (
513
- "measurable_div_to_reciprocal_product " ,
514
- measurable_div_to_reciprocal_product ,
591
+ "measurable_reciprocal_to_power " ,
592
+ measurable_reciprocal_to_power ,
515
593
"basic" ,
516
594
"transform" ,
517
595
)
518
596
597
+
598
+ measurable_ir_rewrites_db .register (
599
+ "measurable_sqrt_sqr_to_power" ,
600
+ measurable_sqrt_sqr_to_power ,
601
+ "basic" ,
602
+ "transform" ,
603
+ )
604
+
605
+
606
+ measurable_ir_rewrites_db .register (
607
+ "measurable_div_to_product" ,
608
+ measurable_div_to_product ,
609
+ "basic" ,
610
+ "transform" ,
611
+ )
612
+
613
+
519
614
measurable_ir_rewrites_db .register (
520
615
"measurable_neg_to_product" ,
521
616
measurable_neg_to_product ,
@@ -601,17 +696,33 @@ def log_jac_det(self, value, *inputs):
601
696
return - at .log (value )
602
697
603
698
604
- class ReciprocalTransform (RVTransform ):
605
- name = "reciprocal"
699
+ class PowerTransform (RVTransform ):
700
+ name = "power"
701
+
702
+ def __init__ (self , power = None ):
703
+ if not isinstance (power , (int , float )):
704
+ raise TypeError (f"Power must be integer or float, got { type (power )} " )
705
+ if power == 0 :
706
+ raise ValueError ("Power cannot be 0" )
707
+ self .power = power
708
+ super ().__init__ ()
606
709
607
710
def forward (self , value , * inputs ):
608
- return at .reciprocal (value )
711
+ at .power (value , self . power )
609
712
610
713
def backward (self , value , * inputs ):
611
- return at .reciprocal (value )
714
+ backward_value = at .power (value , (1 / self .power ))
715
+
716
+ # In this case the transform is not 1-to-1
717
+ if (self .power > 1 ) and (self .power % 2 == 0 ):
718
+ return - backward_value , backward_value
719
+ else :
720
+ return backward_value
612
721
613
722
def log_jac_det (self , value , * inputs ):
614
- return - 2 * at .log (value )
723
+ inv_power = 1 / self .power
724
+ # Note: This fails for value==0
725
+ return np .log (np .abs (inv_power )) + (inv_power - 1 ) * at .log (value )
615
726
616
727
617
728
class IntervalTransform (RVTransform ):
0 commit comments