Skip to content

Commit 9f15330

Browse files
committed
Allow defining mode from compile_* functions
1 parent d8816c7 commit 9f15330

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

pymc/model/core.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ def compile_logp(
612612
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
613613
jacobian: bool = True,
614614
sum: bool = True,
615+
**compile_kwargs,
615616
) -> PointFunc:
616617
"""Compiled log probability density function.
617618
@@ -626,12 +627,13 @@ def compile_logp(
626627
Whether to sum all logp terms or return elemwise logp for each variable.
627628
Defaults to True.
628629
"""
629-
return self.compile_fn(self.logp(vars=vars, jacobian=jacobian, sum=sum))
630+
return self.compile_fn(self.logp(vars=vars, jacobian=jacobian, sum=sum), **compile_kwargs)
630631

631632
def compile_dlogp(
632633
self,
633634
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
634635
jacobian: bool = True,
636+
**compile_kwargs,
635637
) -> PointFunc:
636638
"""Compiled log probability density gradient function.
637639
@@ -643,12 +645,13 @@ def compile_dlogp(
643645
jacobian : bool
644646
Whether to include jacobian terms in logprob graph. Defaults to True.
645647
"""
646-
return self.compile_fn(self.dlogp(vars=vars, jacobian=jacobian))
648+
return self.compile_fn(self.dlogp(vars=vars, jacobian=jacobian), **compile_kwargs)
647649

648650
def compile_d2logp(
649651
self,
650652
vars: Optional[Union[Variable, Sequence[Variable]]] = None,
651653
jacobian: bool = True,
654+
**compile_kwargs,
652655
) -> PointFunc:
653656
"""Compiled log probability density hessian function.
654657
@@ -660,7 +663,7 @@ def compile_d2logp(
660663
jacobian : bool
661664
Whether to include jacobian terms in logprob graph. Defaults to True.
662665
"""
663-
return self.compile_fn(self.d2logp(vars=vars, jacobian=jacobian))
666+
return self.compile_fn(self.d2logp(vars=vars, jacobian=jacobian), **compile_kwargs)
664667

665668
def logp(
666669
self,

0 commit comments

Comments
 (0)