@@ -612,6 +612,7 @@ def compile_logp(
612
612
vars : Optional [Union [Variable , Sequence [Variable ]]] = None ,
613
613
jacobian : bool = True ,
614
614
sum : bool = True ,
615
+ ** compile_kwargs ,
615
616
) -> PointFunc :
616
617
"""Compiled log probability density function.
617
618
@@ -626,12 +627,13 @@ def compile_logp(
626
627
Whether to sum all logp terms or return elemwise logp for each variable.
627
628
Defaults to True.
628
629
"""
629
- return self .compile_fn (self .logp (vars = vars , jacobian = jacobian , sum = sum ))
630
+ return self .compile_fn (self .logp (vars = vars , jacobian = jacobian , sum = sum ), ** compile_kwargs )
630
631
631
632
def compile_dlogp (
632
633
self ,
633
634
vars : Optional [Union [Variable , Sequence [Variable ]]] = None ,
634
635
jacobian : bool = True ,
636
+ ** compile_kwargs ,
635
637
) -> PointFunc :
636
638
"""Compiled log probability density gradient function.
637
639
@@ -643,12 +645,13 @@ def compile_dlogp(
643
645
jacobian : bool
644
646
Whether to include jacobian terms in logprob graph. Defaults to True.
645
647
"""
646
- return self .compile_fn (self .dlogp (vars = vars , jacobian = jacobian ))
648
+ return self .compile_fn (self .dlogp (vars = vars , jacobian = jacobian ), ** compile_kwargs )
647
649
648
650
def compile_d2logp (
649
651
self ,
650
652
vars : Optional [Union [Variable , Sequence [Variable ]]] = None ,
651
653
jacobian : bool = True ,
654
+ ** compile_kwargs ,
652
655
) -> PointFunc :
653
656
"""Compiled log probability density hessian function.
654
657
@@ -660,7 +663,7 @@ def compile_d2logp(
660
663
jacobian : bool
661
664
Whether to include jacobian terms in logprob graph. Defaults to True.
662
665
"""
663
- return self .compile_fn (self .d2logp (vars = vars , jacobian = jacobian ))
666
+ return self .compile_fn (self .d2logp (vars = vars , jacobian = jacobian ), ** compile_kwargs )
664
667
665
668
def logp (
666
669
self ,
0 commit comments