     gradient,
     hessian,
     inputvars,
+    join_nonshared_inputs,
     rewrite_pregrad,
 )
 from pymc.util import (
@@ -172,6 +173,9 @@ def __init__(
         dtype=None,
         casting="no",
         compute_grads=True,
+        model=None,
+        initial_point=None,
+        ravel_inputs: bool | None = None,
         **kwargs,
     ):
         if extra_vars_and_values is None:
@@ -219,9 +223,7 @@ def __init__(
         givens = []
         self._extra_vars_shared = {}
         for var, value in extra_vars_and_values.items():
-            shared = pytensor.shared(
-                value, var.name + "_shared__", shape=[1 if s == 1 else None for s in value.shape]
-            )
+            shared = pytensor.shared(value, var.name + "_shared__", shape=value.shape)
             self._extra_vars_shared[var.name] = shared
             givens.append((var, shared))

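For context, a minimal standalone sketch (not part of the diff) of what this change does to the shared variables for the extra values: the old call only pinned down dimensions of length 1, while the new call records the full static shape of the initial value. The variable name and the example array are illustrative.

```python
import numpy as np
import pytensor

value = np.zeros((3, 1))

# Old approach: only length-1 dimensions are fixed, the rest stay unknown (None).
old_shared = pytensor.shared(
    value, "extra_shared__", shape=[1 if s == 1 else None for s in value.shape]
)
print(old_shared.type.shape)  # (None, 1)

# New approach in the diff: the exact static shape of the initial value is kept.
new_shared = pytensor.shared(value, "extra_shared__", shape=value.shape)
print(new_shared.type.shape)  # (3, 1)
```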
@@ -231,13 +233,28 @@ def __init__(
             grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
             for grad_wrt, var in zip(grads, grad_vars):
                 grad_wrt.name = f"{var.name}_grad"
-            outputs = [cost, *grads]
+            grads = pt.join(0, *[pt.atleast_1d(grad.ravel()) for grad in grads])
+            outputs = [cost, grads]
         else:
             outputs = [cost]

-        inputs = grad_vars
+        if ravel_inputs:
+            if initial_point is None:
+                initial_point = modelcontext(model).initial_point()
+            outputs, raveled_grad_vars = join_nonshared_inputs(
+                point=initial_point, inputs=grad_vars, outputs=outputs, make_inputs_shared=False
+            )
+            inputs = [raveled_grad_vars]
+        else:
+            if ravel_inputs is None:
+                warnings.warn(
+                    "ValueGradFunction will become a function of raveled inputs.\n"
+                    "Specify `ravel_inputs` to suppress this warning. Note that setting `ravel_inputs=False` will be forbidden in a future release."
+                )
+            inputs = grad_vars

         self._pytensor_function = compile_pymc(inputs, outputs, givens=givens, **kwargs)
+        self._raveled_inputs = ravel_inputs

     def set_weights(self, values):
         if values.shape != (self._n_costs - 1,):
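For illustration (not part of the diff), a minimal standalone sketch of the two transformations introduced above: concatenating the per-variable gradients into one flat vector, and, as in the `ravel_inputs=True` branch, collapsing the separate inputs into a single raveled input with `join_nonshared_inputs`. The toy variables `x` and `y` stand in for a model's value variables; the keyword arguments mirror the call in the diff.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import join_nonshared_inputs

# Toy cost over two value variables of different shapes.
x = pt.vector("x")
y = pt.scalar("y")
cost = (x**2).sum() + y**2

# Join the per-variable gradients into one flat gradient vector.
grads = pytensor.grad(cost, [x, y])
flat_grad = pt.join(0, *[pt.atleast_1d(g.ravel()) for g in grads])

# Replace the separate inputs by a single raveled input, as ravel_inputs=True does.
point = {"x": np.zeros(2), "y": np.array(0.0)}
[new_cost, new_grad], flat_input = join_nonshared_inputs(
    point=point, inputs=[x, y], outputs=[cost, flat_grad], make_inputs_shared=False
)

fn = pytensor.function([flat_input], [new_cost, new_grad])
cost_val, grad_val = fn(np.array([1.0, 2.0, 3.0]).astype(flat_input.dtype))
print(cost_val)  # 14.0  (1 + 4 + 9)
print(grad_val)  # [2. 4. 6.]  (d/dx followed by d/dy)
```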
@@ -247,38 +264,29 @@ def set_weights(self, values):
     def set_extra_values(self, extra_vars):
         self._extra_are_set = True
         for var in self._extra_vars:
-            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])
+            self._extra_vars_shared[var.name].set_value(extra_vars[var.name], borrow=True)

     def get_extra_values(self):
         if not self._extra_are_set:
             raise ValueError("Extra values are not set.")

         return {var.name: self._extra_vars_shared[var.name].get_value() for var in self._extra_vars}

-    def __call__(self, grad_vars, grad_out=None, extra_vars=None):
+    def __call__(self, grad_vars, *, extra_vars=None):
         if extra_vars is not None:
             self.set_extra_values(extra_vars)
-
-        if not self._extra_are_set:
+        elif not self._extra_are_set:
             raise ValueError("Extra values are not set.")

         if isinstance(grad_vars, RaveledVars):
-            grad_vars = list(DictToArrayBijection.rmap(grad_vars).values())
-
-        cost, *grads = self._pytensor_function(*grad_vars)
-
-        if grads:
-            grads_raveled = DictToArrayBijection.map(
-                {v.name: gv for v, gv in zip(self._grad_vars, grads)}
-            )
-
-            if grad_out is None:
-                return cost, grads_raveled.data
+            if self._raveled_inputs:
+                grad_vars = (grad_vars.data,)
             else:
-                np.copyto(grad_out, grads_raveled.data)
-                return cost
-        else:
-            return cost
+                grad_vars = DictToArrayBijection.rmap(grad_vars).values()
+        elif self._raveled_inputs and not isinstance(grad_vars, Sequence):
+            grad_vars = (grad_vars,)
+
+        return self._pytensor_function(*grad_vars)

     @property
     def profile(self):
@@ -521,7 +529,14 @@ def root(self):
     def isroot(self):
         return self.parent is None

-    def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
+    def logp_dlogp_function(
+        self,
+        grad_vars=None,
+        tempered=False,
+        initial_point=None,
+        ravel_inputs: bool | None = None,
+        **kwargs,
+    ):
         """Compile a PyTensor function that computes logp and gradient.

         Parameters
@@ -547,13 +562,22 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
         costs = [self.logp()]

         input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
-        ip = self.initial_point(0)
+        if initial_point is None:
+            initial_point = self.initial_point(0)
         extra_vars_and_values = {
-            var: ip[var.name]
+            var: initial_point[var.name]
             for var in self.value_vars
             if var in input_vars and var not in grad_vars
         }
-        return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
+        return ValueGradFunction(
+            costs,
+            grad_vars,
+            extra_vars_and_values,
+            model=self,
+            initial_point=initial_point,
+            ravel_inputs=ravel_inputs,
+            **kwargs,
+        )

     def compile_logp(
         self,
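An end-to-end usage sketch of the opt-in raveling (not part of the diff; the model and variable names are illustrative, and the call pattern follows the `__call__` and `logp_dlogp_function` changes above):

```python
import numpy as np
import pymc as pm
from pymc.blocking import DictToArrayBijection

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    sigma = pm.HalfNormal("sigma", 1)
    pm.Normal("obs", mu, sigma, observed=np.random.randn(10))

# Opt in to the new behaviour: a single raveled input and a single flat gradient.
fn = model.logp_dlogp_function(ravel_inputs=True)
fn.set_extra_values({})  # mark the (empty) extra values as set before calling

point = DictToArrayBijection.map(model.initial_point())
logp, dlogp = fn(point)       # RaveledVars input: fn uses point.data internally
logp, dlogp = fn(point.data)  # or pass the flat numpy array directly
print(logp, dlogp.shape)
```

With `ravel_inputs=True` the gradient comes back as one concatenated vector, so the old `grad_out` output buffer has no counterpart in the new `__call__`.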