Commit c8814f3

complete jax conversion of q-learning implementation

1 parent 61004bc

1 file changed: lectures/mccall_q.md (+152 additions, −121 deletions)
````diff
@@ -491,124 +491,155 @@ The first column of the Q-table represents the value associated with rejecting t
 We use `numba` compilation to accelerate computations.
 
 ```{code-cell} ipython3
-params=[
-    ('c', float64),          # unemployment compensation
-    ('β', float64),          # discount factor
-    ('w', float64[:]),       # array of wage values, w[i] = wage at state i
-    ('q', float64[:]),       # array of probabilities
-    ('eps', float64),        # for epsilon greedy algorithm
-    ('δ', float64),          # Q-table threshold
-    ('lr', float64),         # the learning rate α
-    ('T', int64),            # maximum periods of accepting
-    ('quit_allowed', int64)  # whether quit is allowed after accepting the wage offer
-]
-
-@jitclass(params)
-class Qlearning_McCall:
-    def __init__(self, c=25, β=0.99, w=w_default, q=q_default, eps=0.1,
-                 δ=1e-5, lr=0.5, T=10000, quit_allowed=0):
-
-        self.c, self.β = c, β
-        self.w, self.q = w, q
-        self.eps, self.δ, self.lr, self.T = eps, δ, lr, T
-        self.quit_allowed = quit_allowed
-
-
-    def draw_offer_index(self):
-        """
-        Draw a state index from the wage distribution.
-        """
-
-        q = self.q
-        return np.searchsorted(np.cumsum(q), np.random.random(), side="right")
-
-    def temp_diff(self, qtable, state, accept):
-        """
-        Compute the TD associated with state and action.
-        """
-
-        c, β, w = self.c, self.β, self.w
-
-        if accept==0:
-            state_next = self.draw_offer_index()
-            TD = c + β*np.max(qtable[state_next, :]) - qtable[state, accept]
-        else:
-            state_next = state
-            if self.quit_allowed == 0:
-                TD = w[state_next] + β*np.max(qtable[state_next, :]) - qtable[state, accept]
-            else:
-                TD = w[state_next] + β*qtable[state_next, 1] - qtable[state, accept]
+class QlearningMcCall(NamedTuple):
+    c: float             # unemployment compensation
+    β: float             # discount factor
+    w: jnp.ndarray       # array of wage values, w[i] = wage at state i
+    q: jnp.ndarray       # array of probabilities
+    eps: float           # for epsilon greedy algorithm
+    δ: float             # Q-table threshold
+    lr: float            # the learning rate α
+    T: int               # maximum periods of accepting
+    quit_allowed: int    # whether quit is allowed after accepting the wage offer
+
+
+def create_qlearning_mccall(c=25,
+                            β=0.99,
+                            w=w_default,
+                            q=q_default,
+                            eps=0.1,
+                            δ=1e-5,
+                            lr=0.5,
+                            T=10000,
+                            quit_allowed=0):
+    return QlearningMcCall(c=c,
+                           β=β,
+                           w=w,
+                           q=q,
+                           eps=eps,
+                           δ=δ,
+                           lr=lr,
+                           T=T,
+                           quit_allowed=quit_allowed)
 
+
+@jax.jit
+def draw_offer_index(model, key):
+    """
+    Draw a state index from the wage distribution.
+    """
+    q = model.q
+    random_val = jax.random.uniform(key)
+    return jnp.searchsorted(jnp.cumsum(q), random_val, side="right")
+
+
+@jax.jit
+def temp_diff(model, qtable, state, accept, key):
+    """
+    Compute the TD associated with state and action.
+    """
+    c, β, w = model.c, model.β, model.w
+
+    def reject_case():
+        state_next = draw_offer_index(model, key)
+        TD = c + β * jnp.max(qtable[state_next, :]) - qtable[state, accept]
+        return TD, state_next
+
+    def accept_case():
+        state_next = state
+        TD = jnp.where(model.quit_allowed == 0,
+                       w[state_next] + β * jnp.max(qtable[state_next, :]) - qtable[state, accept],
+                       w[state_next] + β * qtable[state_next, 1] - qtable[state, accept])
         return TD, state_next
 
-    def run_one_epoch(self, qtable, max_times=20000):
-        """
-        Run an "epoch".
-        """
+    return jax.lax.cond(accept == 0, reject_case, accept_case)
+
+
+@jax.jit
+def run_one_epoch(model, qtable, key, max_times=20000):
+    """Run an "epoch"."""
+    eps, δ, lr, T = model.eps, model.δ, model.lr, model.T
+
+    # Split keys for multiple random operations
+    key, subkey1, subkey2 = jax.random.split(key, 3)
+
+    # Initial state
+    s0 = draw_offer_index(model, subkey1)
 
-        c, β, w = self.c, self.β, self.w
-        eps, δ, lr, T = self.eps, self.δ, self.lr, self.T
+    def body_fun(state):
+        qtable, s, accept_count, t, key = state
 
-        s0 = self.draw_offer_index()
-        s = s0
-        accept_count = 0
+        # Split key for this iteration's random operations
+        key, action_key, td_key = jax.random.split(key, 3)
 
-        for t in range(max_times):
+        # Choose action (epsilon-greedy)
+        accept = jnp.argmax(qtable[s, :])
+        random_val = jax.random.uniform(action_key)
+        accept = jnp.where(random_val <= eps, 1 - accept, accept)
 
-            # choose action
-            accept = np.argmax(qtable[s, :])
-            if np.random.random()<=eps:
-                accept = 1 - accept
+        # Update accept count
+        accept_count = jnp.where(accept == 1, accept_count + 1, 0)
 
-            if accept == 1:
-                accept_count += 1
-            else:
-                accept_count = 0
+        # Compute temporal difference
+        TD, s_next = temp_diff(model, qtable, s, accept, td_key)
 
-            TD, s_next = self.temp_diff(qtable, s, accept)
+        # Update qtable
+        qtable_new = qtable.at[s, accept].add(lr * TD)
 
-            # update qtable
-            qtable_new = qtable.copy()
-            qtable_new[s, accept] = qtable[s, accept] + lr*TD
+        # Calculate error
+        error = jnp.max(jnp.abs(qtable_new - qtable))
 
-            if np.max(np.abs(qtable_new-qtable))<=δ:
-                break
+        return qtable_new, s_next, accept_count, t + 1, key
 
-            if accept_count == T:
-                break
+    def cond_fun(state):
+        qtable, s, accept_count, t, key = state
+        # for first interaction, just continue since error is large
+        # for subsequent interactions, compute actual error
+        error = jnp.where(t==0, δ + 1, jnp.max(jnp.abs(qtable - state[0])))
 
-            s, qtable = s_next, qtable_new
+        continue_condition = (error > δ) & (accept_count < T) & (t < max_times)
+        return continue_condition
 
-        return qtable_new
+    # Initial state: (qtable, state, accept_count, iteration, key)
+    init_state = (qtable, s0, 0, 0, subkey2)
+    final_qtable, final_s, final_accept_count, final_t, final_key = jax.lax.while_loop(
+        cond_fun, body_fun, init_state
+    )
 
-@jit
-def run_epochs(N, qlmc, qtable):
+    return final_qtable
+
+
+def run_epochs(N, qlmc, qtable, key):
     """
     Run epochs N times with qtable from the last iteration each time.
     """
-
     for n in range(N):
-        if n%(N/10)==0:
+        if n % (N // 10) == 0:
            print(f"Progress: EPOCHs = {n}")
-        new_qtable = qlmc.run_one_epoch(qtable)
-        qtable = new_qtable
+
+        # Split key for this epoch
+        key, subkey = jax.random.split(key)
+        qtable = run_one_epoch(qlmc, qtable, subkey)
 
     return qtable
 
+
 def valfunc_from_qtable(qtable):
-    return np.max(qtable, axis=1)
+    return jnp.max(qtable, axis=1)
+
 
 def compute_error(valfunc, valfunc_VFI):
-    return np.mean(np.abs(valfunc-valfunc_VFI))
+    return jnp.mean(jnp.abs(valfunc - valfunc_VFI))
 ```
 
 ```{code-cell} ipython3
 # create an instance of Qlearning_McCall
-qlmc = Qlearning_McCall()
+qlmc = create_qlearning_mccall()
 
 # run
-qtable0 = np.zeros((len(w_default), 2))
-qtable = run_epochs(20000, qlmc, qtable0)
+qtable0 = jnp.zeros((len(w_default), 2))
+key, subkey = jax.random.split(key)
+qtable = run_epochs(20000, qlmc, qtable0, subkey)
 ```
 
 ```{code-cell} ipython3
````
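The biggest structural change in the hunk above is that `run_one_epoch` no longer uses a Python `for ... break` loop: every quantity that used to mutate (the Q-table, the current state, the accept counter, the iteration count, and now a PRNG key) is carried through an explicit loop-state tuple handled by `jax.lax.while_loop`, with the stopping rule moved into `cond_fun`. A minimal, self-contained sketch of that carry-state pattern, not part of the commit (the function name and the toy stopping rule are purely illustrative):

```python
import jax
import jax.numpy as jnp

def sqrt2_by_iteration(x0, tol=1e-5, max_iter=1_000):
    """Toy fixed-point iteration x <- (x + 2/x) / 2, written in the
    same carry-state style as `run_one_epoch` above."""
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def body_fun(carry):
        x, _, t = carry
        x_new = 0.5 * (x + 2.0 / x)
        error = jnp.abs(x_new - x)        # change between successive iterates
        return x_new, error, t + 1

    def cond_fun(carry):
        _, error, t = carry
        # keep looping while the change is large and the budget is not spent
        return (error > tol) & (t < max_iter)

    init = (x0, jnp.asarray(jnp.inf, dtype=jnp.float32), 0)
    x_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, init)
    return x_final

print(sqrt2_by_iteration(1.0))   # ≈ 1.41421
```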
````diff
@@ -641,7 +672,7 @@ n, a, b = 30, 200, 100 # default parameters
 q_new = BetaBinomial(n, a, b).pdf() # default choice of q
 
 w_min, w_max = 10, 60
-w_new = np.linspace(w_min, w_max, n+1)
+w_new = jnp.linspace(w_min, w_max, n+1)
 
 
 # plot distribution of wage offer
````
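The only change in this hunk is `np.linspace` → `jnp.linspace`. For array construction like this the two are essentially drop-in replacements, and the resulting JAX array can still be consumed by NumPy-compatible code such as the `matplotlib` calls below. A small illustration (assumed usage, not from the commit):

```python
import numpy as np
import jax.numpy as jnp

w_np = np.linspace(10, 60, 31)    # NumPy version
w_jax = jnp.linspace(10, 60, 31)  # JAX version used in the diff

# Same values; the JAX array converts to NumPy on demand (e.g. for plotting)
assert np.allclose(w_np, np.asarray(w_jax))
```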
````diff
@@ -651,47 +682,47 @@ ax.set_xlabel('wages')
 ax.set_ylabel('probabilities')
 
 plt.show()
-
-# VFI
-mcm = McCallModel(w=w_new, q=q_new)
-valfunc_VFI, flag = mcm.VFI()
 ```
 
 ```{code-cell} ipython3
-mcm = McCallModel(w=w_new, q=q_new)
-valfunc_VFI, flag = mcm.VFI()
+# VFI
+mcm = create_mccall_model(w=w_new, q=q_new)
+valfunc_VFI, flag = VFI(mcm)
 valfunc_VFI
 ```
 
 ```{code-cell} ipython3
-def plot_epochs(epochs_to_plot, quit_allowed=1):
-    "Plot value function implied by outcomes of an increasing number of epochs."
-    qlmc_new = Qlearning_McCall(w=w_new, q=q_new, quit_allowed=quit_allowed)
-    qtable = np.zeros((len(w_new),2))
-    epochs_to_plot = np.asarray(epochs_to_plot)
-    # plot
-    fig, ax = plt.subplots(figsize=(10,6))
-    ax.plot(w_new, valfunc_VFI, '-o', label='VFI')
-
-    max_epochs = np.max(epochs_to_plot)
-    # iterate on epoch numbers
-    for n in range(max_epochs + 1):
-        if n%(max_epochs/10)==0:
-            print(f"Progress: EPOCHs = {n}")
-        if n in epochs_to_plot:
-            valfunc_qlr = valfunc_from_qtable(qtable)
-            error = compute_error(valfunc_qlr, valfunc_VFI)
-
-            ax.plot(w_new, valfunc_qlr, '-o', label=f'QL:epochs={n}, mean error={error}')
-
-
-        new_qtable = qlmc_new.run_one_epoch(qtable)
-        qtable = new_qtable
-
-    ax.set_xlabel('wages')
-    ax.set_ylabel('optimal value')
-    ax.legend(loc='lower right')
-    plt.show()
+def plot_epochs(epochs_to_plot, quit_allowed=1, key=None):
+    "Plot value function implied by outcomes of an increasing number of epochs."
+    if key is None:
+        key = jax.random.PRNGKey(42)  # Default key if none provided
+
+    qlmc_new = create_qlearning_mccall(w=w_new, q=q_new, quit_allowed=quit_allowed)
+    qtable = jnp.zeros((len(w_new),2))
+    epochs_to_plot = jnp.asarray(epochs_to_plot)
+    # plot
+    fig, ax = plt.subplots(figsize=(10,6))
+    ax.plot(w_new, valfunc_VFI, '-o', label='VFI')
+
+    max_epochs = int(jnp.max(epochs_to_plot))  # Convert to Python int
+    # iterate on epoch numbers
+    for n in range(max_epochs + 1):
+        if n%(max_epochs/10)==0:
+            print(f"Progress: EPOCHs = {n}")
+        if n in epochs_to_plot:
+            valfunc_qlr = valfunc_from_qtable(qtable)
+            error = compute_error(valfunc_qlr, valfunc_VFI)
+
+            ax.plot(w_new, valfunc_qlr, '-o', label=f'QL:epochs={n}, mean error={error}')
+
+        # Split key for this epoch
+        key, subkey = jax.random.split(key)
+        qtable = run_one_epoch(qlmc_new, qtable, subkey)
+
+    ax.set_xlabel('wages')
+    ax.set_ylabel('optimal value')
+    ax.legend(loc='lower right')
+    plt.show()
 ```
 
 ```{code-cell} ipython3
````
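`run_epochs` and `plot_epochs` now thread an explicit PRNG key and split it once per epoch, replacing the global `np.random` state used before. A minimal sketch of that key-splitting discipline, illustrative only (the helper name is not from the commit):

```python
import jax

def uniform_draw_per_epoch(key, n_epochs=3):
    """Split the key once per epoch so each epoch sees fresh,
    reproducible randomness (same discipline as run_epochs above)."""
    draws = []
    for _ in range(n_epochs):
        key, subkey = jax.random.split(key)        # keep `key`, consume `subkey`
        draws.append(float(jax.random.uniform(subkey)))
    return draws

key = jax.random.PRNGKey(42)   # the same default seed plot_epochs uses
print(uniform_draw_per_epoch(key))
```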
