|  | 
| 29 | 29 | import xarray as xr | 
| 30 | 30 | 
 | 
| 31 | 31 | from arviz import dict_to_dataset | 
| 32 |  | -from better_optimize.constants import minimize_method | 
|  | 32 | +from better_optimize.constants import minimize_method, root_method | 
| 33 | 33 | from pymc import DictToArrayBijection | 
| 34 | 34 | from pymc.backends.arviz import ( | 
| 35 | 35 |     coords_and_dims_for_inferencedata, | 
|  | 
| 41 | 41 | from pymc.model.transform.optimization import freeze_dims_and_data | 
| 42 | 42 | from pymc.util import get_default_varnames | 
| 43 | 43 | from pytensor.tensor import TensorVariable | 
| 44 |  | -from pytensor.tensor.optimize import minimize | 
|  | 44 | +from pytensor.tensor.optimize import root | 
| 45 | 45 | from scipy import stats | 
| 46 | 46 | 
 | 
| 47 | 47 | from pymc_extras.inference.find_map import ( | 
|  | 
| 55 | 55 | _log = logging.getLogger(__name__) | 
| 56 | 56 | 
 | 
| 57 | 57 | 
 | 
|  | 58 | +def find_mode_jac_hess( | 
|  | 59 | +    x: TensorVariable,  # Should be vector specifically | 
|  | 60 | +    Q: TensorVariable,  # Matrix # TODO tensorinv doesn't have grad implemented yet | 
|  | 61 | +    mu: TensorVariable,  # Vector | 
|  | 62 | +    model: pm.Model | None = None, | 
|  | 63 | +    method: root_method = "hybr", | 
|  | 64 | +    use_jac: bool = True, | 
|  | 65 | +    # use_hess: bool = False, | 
|  | 66 | +    optimizer_kwargs: dict | None = None, | 
|  | 67 | +) -> Callable: | 
|  | 68 | +    """ | 
|  | 69 | +    Returns a function that estimates the mode of the conditional posterior log(p(x | y, params)) by root-finding on a Laplace (Gaussian) approximation, along with the value of that approximation. Wrapper for the (pytensor-native) scipy.optimize.root. | 
|  | 70 | +
 | 
|  | 71 | +    Parameters | 
|  | 72 | +    ---------- | 
|  | 73 | +    x: TensorVariable | 
|  | 74 | +        The variable to optimize over (that is, the latent field in which the mode is found). | 
|  |  | +    Q: TensorVariable | 
|  |  | +        Precision matrix of the Gaussian prior on x. | 
|  |  | +    mu: TensorVariable | 
|  |  | +        Mean vector of the Gaussian prior on x. | 
|  | 75 | +    model: Model | 
|  | 76 | +        PyMC model to use. | 
|  | 77 | +    method: root_method | 
|  | 78 | +        Which root-finding algorithm to use. | 
|  | 79 | +    use_jac: bool | 
|  | 80 | +        If True, the root-finder will use the symbolic Jacobian of the system. | 
|  | 83 | +    optimizer_kwargs: dict | 
|  | 84 | +        Kwargs to pass to scipy.optimize.root. | 
|  | 85 | +
 | 
|  | 86 | +    Returns | 
|  | 87 | +    ------- | 
|  | 88 | +    f: Callable | 
|  | 89 | +        A function which accepts the values of the model's value variables as args and returns [x0, g], where x0 is the estimated mode and g is the Laplace approximation to log(p(x | y, params)) expanded around x0 and evaluated at the input value of x. The input value of x also serves as the initial guess for x0. | 
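|  |  | + | 
|  |  | +    Examples | 
|  |  | +    -------- | 
|  |  | +    A minimal usage sketch. This is illustrative only: the toy model and the names | 
|  |  | +    ``n``, ``y_obs`` and ``x_val`` are hypothetical, and it assumes the function | 
|  |  | +    behaves as documented above. | 
|  |  | + | 
|  |  | +    .. code-block:: python | 
|  |  | + | 
|  |  | +        import numpy as np | 
|  |  | +        import pymc as pm | 
|  |  | +        import pytensor.tensor as pt | 
|  |  | + | 
|  |  | +        n = 10 | 
|  |  | +        Q = pt.eye(n)       # prior precision of the latent field | 
|  |  | +        mu = pt.zeros(n)    # prior mean of the latent field | 
|  |  | +        y_obs = np.random.normal(size=n) | 
|  |  | + | 
|  |  | +        with pm.Model() as m: | 
|  |  | +            x = pm.MvNormal("x", mu=mu, tau=Q) | 
|  |  | +            pm.Normal("y", mu=x, sigma=1.0, observed=y_obs) | 
|  |  | +            x_val = m.rvs_to_values[x]  # value variable for x | 
|  |  | +            f = find_mode_jac_hess(x=x_val, Q=Q, mu=mu, model=m) | 
|  |  | + | 
|  |  | +        # The initial guess for the mode doubles as the value at which the | 
|  |  | +        # Laplace approximation to log p(x | y) is evaluated. | 
|  |  | +        x0, log_laplace_at_x = f(np.zeros(n)) | 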
|  | 90 | +    """ | 
|  | 91 | +    model = pm.modelcontext(model) | 
|  | 92 | + | 
|  | 93 | +    # f = log(p(y | x, params)) | 
|  | 94 | +    f = model.logp() | 
|  | 95 | +    jac = pytensor.gradient.grad(f, x) | 
|  | 96 | +    hess = pytensor.gradient.jacobian(jac.flatten(), x) | 
|  | 97 | + | 
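|  |  | +    # Background on the quadratic form below: with a Gaussian prior x ~ N(mu, Q^{-1}) and | 
|  |  | +    # f(x) = log p(y | x, params), a second-order Taylor expansion of f around a point x* gives | 
|  |  | +    #     log p(x | y, params) ≈ -0.5 * x.T @ (Q - hess) @ x + x.T @ (Q @ mu + jac - hess @ x*) + const, | 
|  |  | +    # with jac and hess evaluated at x*. Setting x* = x symbolically yields the fixed-point | 
|  |  | +    # expression used for the root-finding step below. | 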
|  | 98 | +    # Component of log(p(x | y, params)) which depends on x (for rootfinding) | 
|  | 99 | +    conditional_gaussian_approx = -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x) | 
|  | 100 | + | 
|  | 101 | +    x0, _ = root( | 
|  | 102 | +        equations=pt.stack([conditional_gaussian_approx]), | 
|  | 103 | +        variables=x, | 
|  | 104 | +        method=method, | 
|  | 105 | +        jac=use_jac, | 
|  | 106 | +        optimizer_kwargs=optimizer_kwargs, | 
|  | 107 | +    ) | 
|  | 108 | + | 
|  | 109 | +    # require f'(x0) and f''(x0) for Laplace approx | 
|  | 110 | +    jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) | 
|  | 111 | +    hess = pytensor.graph.replace.graph_replace( | 
|  | 112 | +        hess, {x: x0} | 
|  | 113 | +    )  # Still needed: graph_replace on jac above returns a new graph and does not modify hess | 
|  | 114 | + | 
|  | 115 | +    # Full log(p(x | y, params)) | 
|  | 116 | +    _, logdetQ = pt.nlinalg.slogdet(Q) | 
|  | 117 | +    conditional_gaussian_approx = ( | 
|  | 118 | +        -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ | 
|  | 119 | +    )  # TODO: does redefining conditional_gaussian_approx here affect the graph already passed to root above, before it is compiled? | 
|  | 120 | + | 
|  | 121 | +    args = model.continuous_value_vars + model.discrete_value_vars | 
|  | 122 | +    return pytensor.function( | 
|  | 123 | +        args, [x0, conditional_gaussian_approx] | 
|  | 124 | +    )  # Currently x is passed in both as the initial guess for x0 AND as the value of x itself | 
|  | 125 | + | 
|  | 126 | +    # Minimise negative log likelihood | 
|  | 127 | +    # nll = -model.logp() | 
|  | 128 | +    # soln, _ = minimize( | 
|  | 129 | +    #     objective=nll, | 
|  | 130 | +    #     x=x, | 
|  | 131 | +    #     method=method, | 
|  | 132 | +    #     jac=use_jac, | 
|  | 133 | +    #     hess=use_hess, | 
|  | 134 | +    #     optimizer_kwargs=optimizer_kwargs, | 
|  | 135 | +    # ) | 
|  | 136 | + | 
|  | 137 | +    # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln: | 
|  | 138 | +    # | 
|  | 139 | +    # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()). | 
|  | 140 | +    # | 
|  | 141 | +    # My understanding here is that for some function which evaluates the hessian at x, we're replacing "x" in the hess graph with the subgraph that computes "x" (i.e. soln)? | 
|  | 142 | + | 
|  | 143 | +    # Obtain the Hessian (re-use graph if already computed in minimize) | 
|  | 144 | +    # if use_hess: | 
|  | 145 | +    #     mode, _, hess = ( | 
|  | 146 | +    #         soln.owner.op.inner_outputs | 
|  | 147 | +    #     )  # Note that this mode, _, hess will need to be slightly more elaborate for when use_jac is False (2 items to unpack instead of 3). Just a few if-blocks, but not implemented for now while we're debugging | 
|  | 148 | +    #     hess = pytensor.graph.replace.graph_replace(hess, {mode: soln}) | 
|  | 149 | +    # else: | 
|  | 150 | +    #     hess = pytensor.gradient.hessian(nll, x) | 
|  | 151 | + | 
|  | 152 | +    # Obtain the gradient and Hessian (re-use graphs if already computed in minimize) | 
|  | 153 | +    # res = soln.owner.op.inner_outputs | 
|  | 154 | +    # mode = res[0] | 
|  | 155 | + | 
|  | 156 | +    # print(res) | 
|  | 157 | + | 
|  | 158 | +    # if use_jac: | 
|  | 159 | +    #     # jac = pytensor.gradient.grad(nll, x) | 
|  | 160 | +    #     jac = res.pop(1) | 
|  | 161 | +    # else: | 
|  | 162 | +    #     jac = pytensor.gradient.grad(nll, x) | 
|  | 163 | +    #     jac = pytensor.graph.replace.graph_replace(jac, {x: soln}) | 
|  | 164 | + | 
|  | 165 | +    # print(x) | 
|  | 166 | +    # # jac = pytensor.graph.replace.graph_replace(jac, {x: soln}) | 
|  | 167 | + | 
|  | 168 | +    # jac = -jac # We subsequently want the gradients wrt log(p(y | x)) rather than the negative of this (nll) | 
|  | 169 | + | 
|  | 170 | +    # if use_hess: | 
|  | 171 | +    #     hess = res.pop(1) | 
|  | 172 | +    # else: | 
|  | 173 | +    #     hess = pytensor.gradient.jacobian(jac.flatten(), soln) | 
|  | 174 | +    #     # hess = pytensor.graph.replace.graph_replace(hess, {x: soln}) | 
|  | 175 | + | 
|  | 176 | +    # args = model.continuous_value_vars + model.discrete_value_vars | 
|  | 177 | +    # return pytensor.function(args, [soln, jac, hess]) | 
|  | 178 | + | 
|  | 179 | + | 
| 58 | 180 | def laplace_draws_to_inferencedata( | 
| 59 | 181 |     posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None | 
| 60 | 182 | ) -> az.InferenceData: | 
| @@ -418,69 +540,6 @@ def sample_laplace_posterior( | 
| 418 | 540 |     return idata | 
| 419 | 541 | 
 | 
| 420 | 542 | 
 | 
| 421 |  | -def find_mode_and_hess( | 
| 422 |  | -    x: TensorVariable, | 
| 423 |  | -    model: pm.Model | None = None, | 
| 424 |  | -    method: minimize_method = "BFGS", | 
| 425 |  | -    use_jac: bool = True, | 
| 426 |  | -    use_hess: bool = False,  # TODO Tbh we can probably just remove this arg and pass True to the minimizer all the time, but if this is the case, it will throw a warning when the hessian doesn't need to be computed for a particular optimisation routine. | 
| 427 |  | -    optimizer_kwargs: dict | None = None, | 
| 428 |  | -) -> Callable: | 
| 429 |  | -    """ | 
| 430 |  | -    Returns a function to estimate the mode and hessian of a model by minimizing negative log likelihood. Wrapper for (pytensor-native) scipy.optimize.minimize. | 
| 431 |  | -
 | 
| 432 |  | -    Parameters | 
| 433 |  | -    ---------- | 
| 434 |  | -    x: TensorVariable | 
| 435 |  | -        The parameter with which to minimize wrt (that is, find the mode in x). | 
| 436 |  | -    model: Model | 
| 437 |  | -        PyMC model to use. | 
| 438 |  | -    method: minimize_method | 
| 439 |  | -        Which minimization algorithm to use. | 
| 440 |  | -    use_jac: bool | 
| 441 |  | -        If true, the minimizer will compute and store the Jacobian. | 
| 442 |  | -    use_hess: bool | 
| 443 |  | -        If true, the minimizer will compute and store the Hessian (note that the Hessian will be computed explicitely even if this is False). | 
| 444 |  | -    optimizer_kwargs: dict | 
| 445 |  | -        Kwargs to pass to scipy.optimize.minimize. | 
| 446 |  | -
 | 
| 447 |  | -    Returns | 
| 448 |  | -    ------- | 
| 449 |  | -    f: Callable | 
| 450 |  | -        A function which accepts the values of the model RVs as args and returns [mu, hess(mu)], where mu is the mode. The TensorVariable x is specified as an initial guess for mu in args. | 
| 451 |  | -    """ | 
| 452 |  | -    model = pm.modelcontext(model) | 
| 453 |  | - | 
| 454 |  | -    # Minimise negative log likelihood | 
| 455 |  | -    nll = -model.logp() | 
| 456 |  | -    soln, _ = minimize( | 
| 457 |  | -        objective=nll, | 
| 458 |  | -        x=x, | 
| 459 |  | -        method=method, | 
| 460 |  | -        jac=use_jac, | 
| 461 |  | -        hess=use_hess, | 
| 462 |  | -        optimizer_kwargs=optimizer_kwargs, | 
| 463 |  | -    ) | 
| 464 |  | - | 
| 465 |  | -    # TODO: Jesse suggested I use this graph_replace function, but it seems that "mode" here is a different type to soln: | 
| 466 |  | -    # | 
| 467 |  | -    # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()). | 
| 468 |  | -    # | 
| 469 |  | -    # My understanding here is that for some function which evaluates the hessian at x, we're replacing "x" in the hess graph with the subgraph that computes "x" (i.e. soln)? | 
| 470 |  | - | 
| 471 |  | -    # Obtain the Hessian (re-use graph if already computed in minimize) | 
| 472 |  | -    if use_hess: | 
| 473 |  | -        mode, _, hess = ( | 
| 474 |  | -            soln.owner.op.inner_outputs | 
| 475 |  | -        )  # Note that this mode, _, hess will need to be slightly more elaborate for when use_jac is False (2 items to unpack instead of 3). Just a few if-blocks, but not implemented for now while we're debugging | 
| 476 |  | -        hess = pytensor.graph.replace.graph_replace(hess, {mode: soln}) | 
| 477 |  | -    else: | 
| 478 |  | -        hess = pytensor.gradient.hessian(nll, x) | 
| 479 |  | - | 
| 480 |  | -    args = model.continuous_value_vars + model.discrete_value_vars | 
| 481 |  | -    return pytensor.function(args, [soln, hess]) | 
| 482 |  | - | 
| 483 |  | - | 
| 484 | 543 | def fit_laplace( | 
| 485 | 544 |     optimize_method: minimize_method | Literal["basinhopping"] = "BFGS", | 
| 486 | 545 |     *, | 
|  | 