Skip to content

Commit 9b1cd0e

Browse files
Update docstring
1 parent 4d88343 commit 9b1cd0e

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

pymc_experimental/inference/find_map.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
9292
f_untransform = pytensor.function(
9393
inputs=[pytensor.In(X, borrow=True)],
9494
outputs=pytensor.Out(out, borrow=True),
95-
mode=Mode(linker="py", optimizer=None),
95+
mode=Mode(linker="py", optimizer="FAST_COMPILE"),
9696
)
9797
return f_untransform(posterior_draws)
9898

@@ -223,7 +223,6 @@ def scipy_optimize_funcs_from_loss(
223223
"""
224224
Compile loss functions for use with scipy.optimize.minimize.
225225
226-
227226
Parameters
228227
----------
229228
loss: TensorVariable
@@ -238,8 +237,8 @@ def scipy_optimize_funcs_from_loss(
238237
Whether to compile a function that computes the Hessian of the loss function.
239238
use_hessp: bool
240239
Whether to compile a function that computes the Hessian-vector product of the loss function.
241-
gradient_backend: str, one of "jax" or "pytensor"
242-
Which backend to use to compute gradients.
240+
gradient_backend: str, default "pytensor"
241+
Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
243242
compile_kwargs:
244243
Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
245244

0 commit comments

Comments
 (0)