diff --git a/nbs/03_fit.ipynb b/nbs/03_fit.ipynb index 4526d62..6366ee8 100644 --- a/nbs/03_fit.ipynb +++ b/nbs/03_fit.ipynb @@ -38,7 +38,7 @@ "# export\n", "import jax\n", "from fax.implicit import twophase\n", - "import jax.experimental.optimizers as optimizers\n", + "from jax.experimental import optix\n", "\n", "from neos.transforms import to_bounded_vec, to_inf_vec, to_bounded, to_inf\n", "from neos.models import *" @@ -61,7 +61,7 @@ "):\n", " '''\n", " Wraps a series of functions that perform maximum likelihood fitting in the \n", - " `two_phase_solver` method found in the `fax` python module. This allows for\n", + " `two_phase_solve` method found in the `fax` python module. This allows for\n", " the calculation of gradients of the best-fit parameters with respect to upstream\n", " parameters that control the underlying model, i.e. the event yields (which are \n", " then parameterized by weights or similar).\n", @@ -74,7 +74,7 @@ " respectively. Differentiable :)\n", " '''\n", "\n", - " adam_init, adam_update, adam_get_params = optimizers.adam(1e-6)\n", + " gradient_descent = optix.scale(-1e-2)\n", "\n", " def make_model(hyper_pars):\n", " constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1]\n", @@ -108,51 +108,58 @@ " )\n", " return -expected_logpdf(pars)[0]\n", "\n", - " return constrained_mu, global_fit_objective, constrained_fit_objective,bounds\n", + " return constrained_mu, global_fit_objective, constrained_fit_objective, bounds\n", "\n", " def global_bestfit_minimized(hyper_param):\n", " _, nll, _ ,_ = make_model(hyper_param)\n", "\n", - " def bestfit_via_grad_descent(i, param): # gradient descent\n", + " def bestfit_via_grad_descent(param): # gradient descent\n", " g = jax.grad(nll)(param)\n", - " # param = param - g * learning_rate\n", - " param = adam_get_params(adam_update(i,g,adam_init(param)))\n", - " return param\n", + " updates, _ = gradient_descent.update(g, gradient_descent.init(param))\n", + " return optix.apply_updates(param, updates)\n", "\n", " return bestfit_via_grad_descent\n", + " \n", "\n", " def constrained_bestfit_minimized(hyper_param):\n", - " mu, nll, cnll,bounds = make_model(hyper_param)\n", + " mu, nll, cnll, bounds = make_model(hyper_param)\n", "\n", - " def bestfit_via_grad_descent(i, param): # gradient descent\n", + " def bestfit_via_grad_descent(param): # gradient descent\n", " _, np = param[0], param[1:]\n", " g = jax.grad(cnll)(np)\n", - " np = adam_get_params(adam_update(i,g,adam_init(np)))\n", + " updates, _ = gradient_descent.update(g, gradient_descent.init(np))\n", + " np = optix.apply_updates(np, updates)\n", " param = jax.numpy.concatenate([jax.numpy.asarray([mu]), np])\n", " return param\n", + " \n", "\n", " return bestfit_via_grad_descent\n", - "\n", - " global_solve = twophase.two_phase_solver(\n", - " param_func=global_bestfit_minimized,\n", - " default_rtol=default_rtol,\n", - " default_atol=default_atol,\n", - " default_max_iter=default_max_iter\n", + " \n", + " convergence_test = twophase.default_convergence_test(\n", + " rtol=default_rtol,\n", + " atol=default_atol,\n", " )\n", - " constrained_solver = twophase.two_phase_solver(\n", - " param_func=constrained_bestfit_minimized,\n", - " default_rtol=default_rtol,\n", - " default_atol=default_atol,\n", - " default_max_iter=default_max_iter,\n", + " global_solver = twophase.default_solver(\n", + " convergence_test=convergence_test,\n", + " max_iter=default_max_iter,\n", " )\n", + " constrained_solver = global_solver\n", "\n", " def g_fitter(init, hyper_pars):\n", - " solve = 
global_solve(init, hyper_pars)\n", - " return solve.value\n", + " return twophase.two_phase_solve(\n", + " global_bestfit_minimized,\n", + " init,\n", + " hyper_pars,\n", + " solvers=(global_solver,),\n", + " )\n", "\n", " def c_fitter(init, hyper_pars):\n", - " solve = constrained_solver(init, hyper_pars)\n", - " return solve.value\n", + " return twophase.two_phase_solve(\n", + " constrained_bestfit_minimized,\n", + " init,\n", + " hyper_pars,\n", + " solvers=(constrained_solver,),\n", + " )\n", "\n", " return g_fitter, c_fitter" ] diff --git a/neos/fit.py b/neos/fit.py index 29ae0ad..92b1865 100644 --- a/neos/fit.py +++ b/neos/fit.py @@ -5,7 +5,7 @@ # Cell import jax from fax.implicit import twophase -import jax.experimental.optimizers as optimizers +from jax.experimental import optix from .transforms import to_bounded_vec, to_inf_vec, to_bounded, to_inf from .models import * @@ -21,7 +21,7 @@ def get_solvers( ): ''' Wraps a series of functions that perform maximum likelihood fitting in the - `two_phase_solver` method found in the `fax` python module. This allows for + `two_phase_solve` method found in the `fax` python module. This allows for the calculation of gradients of the best-fit parameters with respect to upstream parameters that control the underlying model, i.e. the event yields (which are then parameterized by weights or similar). @@ -34,7 +34,7 @@ def get_solvers( respectively. Differentiable :) ''' - adam_init, adam_update, adam_get_params = optimizers.adam(1e-6) + gradient_descent = optix.scale(-1e-2) def make_model(hyper_pars): constrained_mu, nn_pars = hyper_pars[0], hyper_pars[1] @@ -68,50 +68,57 @@ def constrained_fit_objective(nuis_par): # NLL ) return -expected_logpdf(pars)[0] - return constrained_mu, global_fit_objective, constrained_fit_objective,bounds + return constrained_mu, global_fit_objective, constrained_fit_objective, bounds def global_bestfit_minimized(hyper_param): _, nll, _ ,_ = make_model(hyper_param) - def bestfit_via_grad_descent(i, param): # gradient descent + def bestfit_via_grad_descent(param): # gradient descent g = jax.grad(nll)(param) - # param = param - g * learning_rate - param = adam_get_params(adam_update(i,g,adam_init(param))) - return param + updates, _ = gradient_descent.update(g, gradient_descent.init(param)) + return optix.apply_updates(param, updates) return bestfit_via_grad_descent + def constrained_bestfit_minimized(hyper_param): - mu, nll, cnll,bounds = make_model(hyper_param) + mu, nll, cnll, bounds = make_model(hyper_param) - def bestfit_via_grad_descent(i, param): # gradient descent + def bestfit_via_grad_descent(param): # gradient descent _, np = param[0], param[1:] g = jax.grad(cnll)(np) - np = adam_get_params(adam_update(i,g,adam_init(np))) + updates, _ = gradient_descent.update(g, gradient_descent.init(np)) + np = optix.apply_updates(np, updates) param = jax.numpy.concatenate([jax.numpy.asarray([mu]), np]) return param + return bestfit_via_grad_descent - global_solve = twophase.two_phase_solver( - param_func=global_bestfit_minimized, - default_rtol=default_rtol, - default_atol=default_atol, - default_max_iter=default_max_iter + convergence_test = twophase.default_convergence_test( + rtol=default_rtol, + atol=default_atol, ) - constrained_solver = twophase.two_phase_solver( - param_func=constrained_bestfit_minimized, - default_rtol=default_rtol, - default_atol=default_atol, - default_max_iter=default_max_iter, + global_solver = twophase.default_solver( + convergence_test=convergence_test, + max_iter=default_max_iter, ) + 
constrained_solver = global_solver def g_fitter(init, hyper_pars): - solve = global_solve(init, hyper_pars) - return solve.value + return twophase.two_phase_solve( + global_bestfit_minimized, + init, + hyper_pars, + solvers=(global_solver,), + ) def c_fitter(init, hyper_pars): - solve = constrained_solver(init, hyper_pars) - return solve.value + return twophase.two_phase_solve( + constrained_bestfit_minimized, + init, + hyper_pars, + solvers=(constrained_solver,), + ) return g_fitter, c_fitter \ No newline at end of file
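
Note on the optimizer swap: `optix.scale(-1e-2)` implements plain gradient descent with step size 1e-2; the update is just the gradient multiplied by -0.01, applied with `optix.apply_updates`. A minimal sketch of that update pattern on a toy objective, assuming the `jax.experimental.optix` API this diff targets (`toy_nll`, `one_step`, and the loop are illustrative only, not part of the change):

```python
import jax
import jax.numpy as jnp
from jax.experimental import optix

# Same transformation as in get_solvers: multiply gradients by -0.01.
gradient_descent = optix.scale(-1e-2)

def toy_nll(pars):
    # Stand-in for the global/constrained NLL objectives built by make_model.
    return jnp.sum((pars - 1.0) ** 2)

def one_step(pars, state):
    g = jax.grad(toy_nll)(pars)
    updates, state = gradient_descent.update(g, state)  # scale the gradient
    return optix.apply_updates(pars, updates), state     # pars + updates

pars = jnp.zeros(2)
state = gradient_descent.init(pars)  # scale() is stateless, so this state is trivial
for _ in range(200):
    pars, state = one_step(pars, state)
# pars moves toward the minimum at [1., 1.]
```

Because `scale()` carries no state, re-initialising it inside `bestfit_via_grad_descent` on every call, as the diff does, is equivalent to threading the state through as above. `two_phase_solve` then iterates that one-step map to a fixed point and provides gradients of the solution with respect to `hyper_pars`, which is what keeps the `g_fitter`/`c_fitter` functions returned by `get_solvers` differentiable.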