-
Notifications
You must be signed in to change notification settings - Fork 1
Open
Description
Hi! Thanks for this implementation!
I am trying to use this implementation but I am running into the error in the title of this issue. Here is what I am working with:
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from typing import Callable, Tuple
import haiku as hk
from jax.random import PRNGKey, split
import optax
import matplotlib.pyplot as plt
from optax_swag import swag
def nll(apply_fn: Callable):
    """Build a Gaussian negative log-likelihood loss around ``apply_fn``.

    Args:
        apply_fn: Callable invoked as ``apply_fn(params, x)`` to produce
            predictions.

    Returns:
        A function ``(params, batch) -> scalar`` where ``batch`` is an
        ``(x, y)`` pair. The scalar is the negated sum of the element-wise
        normal log-densities.
    """

    def _nll(params, batch: Tuple[jax.Array, jax.Array]) -> jax.Array:
        x, y = batch
        out = apply_fn(params, x)
        # logpdf(value, loc): `y` is the mean and the scale stays at its
        # default of 1, so this is a unit-variance Gaussian likelihood.
        ll = stats.norm.logpdf(out, y)
        # The result is a zero-dimensional JAX array, not a Python float.
        return -ll.sum()

    return _nll
def generate_data(
    n_points: int = 25,
    x_min: float = 0.0,
    x_max: float = 10.0,
) -> Tuple[jax.Array, jax.Array]:
    """Create a small 1-D regression set ``y = sin(0.4 * x) + 3``.

    Args:
        n_points: Number of evenly spaced inputs (default 25).
        x_min: First input value (default 0).
        x_max: Last input value, inclusive (default 10).

    Returns:
        ``(x, y)``, both column arrays of shape ``(n_points, 1)``.
    """
    # reshape(-1, 1) turns the flat grid into one feature column per sample.
    x = jnp.linspace(x_min, x_max, n_points).reshape(-1, 1)
    y = jnp.sin(0.4 * x) + 3
    return x, y
def make_small_mlp():
    """Return a Haiku-transformed MLP: three 50-unit ReLU layers, one output."""

    def small_mlp(x):
        # Three hidden Linear(50) layers, each followed by ReLU, created in
        # the same order as a literal list would create them.
        layers = []
        for _ in range(3):
            layers.append(hk.Linear(50))
            layers.append(jax.nn.relu)
        # Single-unit output head, no activation.
        layers.append(hk.Linear(1))
        network = hk.Sequential(layers)
        return network(x)

    return hk.transform(small_mlp)
def train_model(params, model_apply, data, opt_init, opt_update, epochs, loss_fn):
    """Train with full-batch steps, then plot targets against predictions.

    Args:
        params: Initial parameter pytree accepted by ``model_apply``.
        model_apply: Callable ``(params, x) -> predictions``.
        data: ``(x, y)`` pair; the whole pair is used as one batch per step.
        opt_init: Optimizer ``init`` function, called as ``opt_init(params)``.
        opt_update: Optimizer ``update`` function, called as
            ``opt_update(grads, opt_state, params)``.
        epochs: Number of optimisation steps to run.
        loss_fn: Loss factory. It is called once as ``loss_fn(model_apply)``
            and must return a callable ``(params, batch) -> scalar``.

    Returns:
        The parameters after the final step.

    Side effects:
        Prints the initial optimizer state and the loss at every step, and
        shows a matplotlib figure of targets versus final predictions.
    """
    # Build the objective from the argument instead of silently replacing it
    # with a module-level factory, so the parameter is actually honoured.
    objective = loss_fn(model_apply)
    x, y = data
    opt_state = opt_init(params)
    print(opt_state)

    @jax.jit
    def train_one_epoch(params, opt_state):
        nll_val, grad = jax.value_and_grad(objective)(params, (x, y))
        # Pass the current params explicitly: the update function's `params`
        # argument defaults to None, and a transformation that reads the
        # parameters then fails on the None. That is presumably the
        # 'Expected dict, got None' error seen with `swag`; confirm against
        # its implementation.
        updates, opt_state = opt_update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, nll_val

    for i in range(epochs):
        params, opt_state, nll_val = train_one_epoch(params, opt_state)
        print(f"STEP {i} | NLL: {nll_val}")

    preds = model_apply(params, x)
    plt.plot(x, y)
    plt.plot(x, preds)
    plt.show()
    return params
# Split the seed key five ways; only the first sub-key is used, to
# initialise the model parameters.
model_init_key, *_ = split(PRNGKey(123), 5)

x, y = generate_data()
mlp = make_small_mlp()
# Initialise from a single sample, x[0].
params = mlp.init(model_init_key, x[0])


def model_apply(params, x):
    """Apply the transformed MLP, passing ``None`` as the RNG key."""
    return mlp.apply(params, None, x)


# Adam followed by the SWAG transformation; the chained optimiser unpacks
# into its (init, update) pair.
optimizer = optax.chain(optax.adam(1e-3), swag(5, 5))
opt_init, opt_update = optimizer

params = train_model(params, model_apply, (x, y), opt_init, opt_update, 500, nll)
# 'ValueError: Expected dict, got None.' Any ideas of what I may be doing wrong? Thanks so much!
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels