1616from functools import singledispatch
1717
1818import numpy as np
19+ import pytensor
1920import pytensor .tensor as pt
2021
2122
@@ -179,6 +180,33 @@ def __init__(self, n):
179180 """
180181 self .n = n
181182
183+ def step (self , i , counter , L , y ):
184+ y_star = y [counter : counter + i ]
185+ dsy = y_star .dot (y_star )
186+ alpha_r = 1 / (dsy + 1 )
187+ gamma = pt .sqrt (dsy + 2 ) * alpha_r
188+
189+ x = pt .join (0 , gamma * y_star , pt .atleast_1d (alpha_r ))
190+ next_L = L [i , : i + 1 ].set (x )
191+ log_det = pt .log (2 ) + 0.5 * (i - 2 ) * pt .log (dsy + 2 ) - i * pt .log (1 + dsy )
192+
193+ return next_L , log_det
194+
195+ def _compute_L_and_logdet_scan (self , value , * inputs ):
196+ L = pt .eye (self .n )
197+ idxs = pt .arange (1 , self .n )
198+ counters = pt .arange (0 , self .n ).cumsum ()
199+
200+ results , _ = pytensor .scan (
201+ self .step , outputs_info = [L , None ], sequences = [idxs , counters ], non_sequences = [value ]
202+ )
203+
204+ L_seq , log_det_seq = results
205+ L = L_seq [- 1 ]
206+ log_det = pt .sum (log_det_seq )
207+
208+ return L , log_det
209+
182210 def _compute_L_and_logdet (self , value , * inputs ):
183211 n = self .n
184212 counter = 0
@@ -201,7 +229,7 @@ def _compute_L_and_logdet(self, value, *inputs):
201229 return L , log_det
202230
203231 def backward (self , value , * inputs ):
204- L , _ = self ._compute_L_and_logdet (value , * inputs )
232+ L , _ = self ._compute_L_and_logdet_scan (value , * inputs )
205233 return L
206234
207235 def forward (self , value , * inputs ):
@@ -211,7 +239,7 @@ def forward(self, value, *inputs):
211239 return pt .as_tensor_variable (np .random .normal (size = size ))
212240
213241 def log_jac_det (self , value , * inputs ):
214- _ , log_det = self ._compute_L_and_logdet (value , * inputs )
242+ _ , log_det = self ._compute_L_and_logdet_scan (value , * inputs )
215243 return log_det
216244
217245
0 commit comments