 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor.math import Max
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.sort import SortOp
 from pytensor.tensor.variable import TensorVariable
 
 from pymc.logprob.abstract import (
-    MeasurableElemwise,
     MeasurableOp,
     _logcdf_helper,
     _logprob,
     _logprob_helper,
 )
 from pymc.logprob.rewriting import measurable_ir_rewrites_db
-from pymc.logprob.utils import filter_measurable_variables
+from pymc.logprob.utils import (
+    CheckParameterValue,
+    check_potential_measurability,
+    filter_measurable_variables,
+)
 from pymc.math import logdiffexp
 from pymc.pytensorf import constant_fold
 
 
+def _underlying_iid_rv(variable) -> TensorVariable | None:
+    # Check whether an IID base RV is connected to the variable through IID-preserving elemwise operations
+    from pymc.distributions.distribution import SymbolicRandomVariable
+    from pymc.logprob.transforms import MeasurableTransform
+
+    def iid_elemwise_root(var: TensorVariable) -> TensorVariable | None:
+        node = var.owner
+        if node is None:
+            return None
+        if isinstance(node.op, RandomVariable | SymbolicRandomVariable):
+            return var
+        elif isinstance(node.op, MeasurableTransform):
+            if len(node.inputs) == 1:
+                return iid_elemwise_root(node.inputs[0])
+            else:
+                # If the non-measurable inputs are broadcasted, it is still an IID operation.
+                measurable_inp = node.op.measurable_input_idx
+                other_inputs = [inp for i, inp in enumerate(node.inputs) if i != measurable_inp]
+                if all(all(other_inp.type.broadcastable) for other_inp in other_inputs):
+                    return iid_elemwise_root(node.inputs[measurable_inp])
+        return None
+
+    # Check that the root is a univariate distribution linked only by elemwise operations
+    latent_base_var = iid_elemwise_root(variable)
+
+    if latent_base_var is None:
+        return None
+
+    latent_op = latent_base_var.owner.op
+
+    if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
+        return None
+
+    if not all(
+        all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
+    ):
+        return None
+
+    return cast(TensorVariable, latent_base_var)
+
+
 class MeasurableMax(MeasurableOp, Max):
     """A placeholder used to specify a log-likelihood for a max sub-graph."""
 
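For intuition, a minimal usage sketch (not part of the diff; it assumes `pm.logp` and `pt.max`, and that the IR rewrites have already turned `exp` into a `MeasurableTransform` by the time the helper runs) of which graphs `_underlying_iid_rv` accepts:

import pymc as pm
import pytensor.tensor as pt

x = pm.Normal.dist(0.0, 1.0, size=(3,))  # scalar params broadcast to all draws: IID
pm.logp(pt.max(pt.exp(x)), 1.0)  # accepted: the elemwise chain leads back to an IID RV

y = pm.Normal.dist(mu=[0.0, 1.0, 2.0])  # per-element mu: draws are not IID
# pm.logp(pt.max(y), 1.0) fails here, since _underlying_iid_rv finds no IID root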
@@ -77,31 +121,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
     if not filter_measurable_variables(node.inputs):
         return None
 
-    # We allow Max of RandomVariables or Elemwise of univariate RandomVariables
-    if isinstance(base_var.owner.op, MeasurableElemwise):
-        latent_base_vars = [
-            var
-            for var in base_var.owner.inputs
-            if (var.owner and isinstance(var.owner.op, MeasurableOp))
-        ]
-        if len(latent_base_vars) != 1:
-            return None
-        [latent_base_var] = latent_base_vars
-    else:
-        latent_base_var = base_var
-
-    latent_op = latent_base_var.owner.op
-    if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
-        return None
+    # We allow Max of RandomVariables or IID Elemwise of univariate RandomVariables
+    latent_base_var = _underlying_iid_rv(base_var)
 
-    # univariate i.i.d. test which also rules out other distributions
-    if not all(
-        all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
-    ):
+    if latent_base_var is None:
         return None
 
-    base_var = cast(TensorVariable, base_var)
-
     if node.op.axis is None:
         axis = tuple(range(base_var.ndim))
     else:
@@ -119,7 +144,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
 
 
 measurable_ir_rewrites_db.register(
-    "find_measurable_max",
+    find_measurable_max.__name__,
     find_measurable_max,
     "basic",
     "max",
@@ -158,3 +183,54 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
 
     n = pt.prod(base_rv_shape)
     return logdiffexp(n * logcdf, n * logcdf_prev)
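Note on the discrete case: for n IID discrete variables, P(max = m) = F(m)^n - F(m - 1)^n, so the log-probability is logdiffexp(n * logcdf(m), n * logcdf(m - 1)), which is exactly the expression returned above.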
+
+
+class MeasurableSort(MeasurableOp, SortOp):
+    """A placeholder used to specify a log-likelihood for a sort sub-graph."""
+
+
+@_logprob.register(MeasurableSort)
+def sort_logprob(op, values, base_rv, axis, **kwargs):
+    r"""Compute the log-likelihood graph for the `Sort` operation."""
+    (value,) = values
+
+    logprob = _logprob_helper(base_rv, value).sum(axis=-1)
+
+    base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
+    n = pt.prod(base_rv_shape, axis=-1)
+    sorted_logp = pt.gammaln(n + 1) + logprob
+
+    # The sorted value is not really a parameter, but we include the check in
+    # `CheckParameterValue` to avoid costly sorting if `check_bounds=False` in a PyMC model
+    return CheckParameterValue("value must be sorted", can_be_replaced_by_ninf=True)(
+        sorted_logp, pt.eq(value, value.sort(axis=axis, kind=op.kind)).all()
+    )
+
+
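A minimal sanity check for the density above (a sketch with assumed names, relying on `TensorVariable.sort`): the joint density of the sorted vector of n IID draws is n! times the product of the individual densities, so for two standard normals the logp should equal log 2! plus the summed logpdfs:

import numpy as np
import pymc as pm
from scipy import stats

x = pm.Normal.dist(size=(2,))
logp = pm.logp(x.sort(), np.array([-0.5, 1.0])).eval()
expected = np.log(2) + stats.norm.logpdf([-0.5, 1.0]).sum()  # log 2! + sum of logpdfs
np.testing.assert_allclose(logp, expected)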
+@node_rewriter(tracks=[SortOp])
+def find_measurable_sort(fgraph, node):
+    if isinstance(node.op, MeasurableSort):
+        return None
+
+    if not filter_measurable_variables(node.inputs):
+        return None
+
+    [base_var, axis] = node.inputs
+
+    # We allow Sort of RandomVariables or IID Elemwise of univariate RandomVariables
+    if _underlying_iid_rv(base_var) is None:
+        return None
+
+    # Check that axis is not potentially measurable
+    if check_potential_measurability([axis]):
+        return None
+
+    return [MeasurableSort(**node.op._props_dict())(base_var, axis)]
+
+
+measurable_ir_rewrites_db.register(
+    find_measurable_sort.__name__,
+    find_measurable_sort,
+    "basic",
+    "sort",
+)
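For contrast, a hypothetical sketch of a graph the new rewrite deliberately refuses: per-element parameters break the IID requirement, so no logprob can be derived for the sort and the derivation errors out:

import numpy as np
import pymc as pm

y = pm.Normal.dist(mu=[0.0, 1.0, 2.0])  # vector mu: draws are not IID
try:
    pm.logp(y.sort(), np.zeros(3))
except Exception as err:
    # logp derivation fails because find_measurable_sort does not apply
    print(type(err).__name__)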