@@ -491,124 +491,155 @@ The first column of the Q-table represents the value associated with rejecting t
-We use `numba` compilation to accelerate computations.
+We use JAX to accelerate computations.

```{code-cell} ipython3
-params=[
-    ('c', float64),            # unemployment compensation
-    ('β', float64),            # discount factor
-    ('w', float64[:]),         # array of wage values, w[i] = wage at state i
-    ('q', float64[:]),         # array of probabilities
-    ('eps', float64),          # for epsilon greedy algorithm
-    ('δ', float64),            # Q-table threshold
-    ('lr', float64),           # the learning rate α
-    ('T', int64),              # maximum periods of accepting
-    ('quit_allowed', int64)    # whether quit is allowed after accepting the wage offer
-]
-
-@jitclass(params)
-class Qlearning_McCall:
-    def __init__(self, c=25, β=0.99, w=w_default, q=q_default, eps=0.1,
-                 δ=1e-5, lr=0.5, T=10000, quit_allowed=0):
-
-        self.c, self.β = c, β
-        self.w, self.q = w, q
-        self.eps, self.δ, self.lr, self.T = eps, δ, lr, T
-        self.quit_allowed = quit_allowed
-
-
-    def draw_offer_index(self):
-        """
-        Draw a state index from the wage distribution.
-        """
-
-        q = self.q
-        return np.searchsorted(np.cumsum(q), np.random.random(), side="right")
-
-    def temp_diff(self, qtable, state, accept):
-        """
-        Compute the TD associated with state and action.
-        """
-
-        c, β, w = self.c, self.β, self.w
-
-        if accept==0:
-            state_next = self.draw_offer_index()
-            TD = c + β*np.max(qtable[state_next, :]) - qtable[state, accept]
-        else:
-            state_next = state
-            if self.quit_allowed == 0:
-                TD = w[state_next] + β*np.max(qtable[state_next, :]) - qtable[state, accept]
-            else:
-                TD = w[state_next] + β*qtable[state_next, 1] - qtable[state, accept]
+class QlearningMcCall(NamedTuple):
+    c: float            # unemployment compensation
+    β: float            # discount factor
+    w: jnp.ndarray      # array of wage values, w[i] = wage at state i
+    q: jnp.ndarray      # array of probabilities
+    eps: float          # for epsilon greedy algorithm
+    δ: float            # Q-table threshold
+    lr: float           # the learning rate α
+    T: int              # maximum periods of accepting
+    quit_allowed: int   # whether quit is allowed after accepting the wage offer
+
+
+def create_qlearning_mccall(c=25,
+                            β=0.99,
+                            w=w_default,
+                            q=q_default,
+                            eps=0.1,
+                            δ=1e-5,
+                            lr=0.5,
+                            T=10000,
+                            quit_allowed=0):
+    return QlearningMcCall(c=c,
+                           β=β,
+                           w=w,
+                           q=q,
+                           eps=eps,
+                           δ=δ,
+                           lr=lr,
+                           T=T,
+                           quit_allowed=quit_allowed)

+
+@jax.jit
+def draw_offer_index(model, key):
+    """
+    Draw a state index from the wage distribution.
+    """
+    q = model.q
+    random_val = jax.random.uniform(key)
+    return jnp.searchsorted(jnp.cumsum(q), random_val, side="right")
+
+
+@jax.jit
+def temp_diff(model, qtable, state, accept, key):
+    """
+    Compute the TD associated with state and action.
+    """
+    c, β, w = model.c, model.β, model.w
+
+    def reject_case():
+        state_next = draw_offer_index(model, key)
+        TD = c + β * jnp.max(qtable[state_next, :]) - qtable[state, accept]
+        return TD, state_next
+
+    def accept_case():
+        state_next = state
+        TD = jnp.where(model.quit_allowed == 0,
+                       w[state_next] + β * jnp.max(qtable[state_next, :]) - qtable[state, accept],
+                       w[state_next] + β * qtable[state_next, 1] - qtable[state, accept])
        return TD, state_next

-    def run_one_epoch(self, qtable, max_times=20000):
-        """
-        Run an "epoch".
-        """
+    return jax.lax.cond(accept == 0, reject_case, accept_case)
+
+
+@jax.jit
+def run_one_epoch(model, qtable, key, max_times=20000):
+    """Run an "epoch"."""
+    eps, δ, lr, T = model.eps, model.δ, model.lr, model.T
+
+    # Split keys for multiple random operations
+    key, subkey1, subkey2 = jax.random.split(key, 3)
+
+    # Initial state
+    s0 = draw_offer_index(model, subkey1)

-        c, β, w = self.c, self.β, self.w
-        eps, δ, lr, T = self.eps, self.δ, self.lr, self.T
+    def body_fun(state):
+        qtable, s, accept_count, t, key, _ = state

-        s0 = self.draw_offer_index()
-        s = s0
-        accept_count = 0
+        # Split key for this iteration's random operations
+        key, action_key, td_key = jax.random.split(key, 3)

-        for t in range(max_times):
+        # Choose action (epsilon-greedy)
+        accept = jnp.argmax(qtable[s, :])
+        random_val = jax.random.uniform(action_key)
+        accept = jnp.where(random_val <= eps, 1 - accept, accept)

-            # choose action
-            accept = np.argmax(qtable[s, :])
-            if np.random.random()<=eps:
-                accept = 1 - accept
+        # Update accept count
+        accept_count = jnp.where(accept == 1, accept_count + 1, 0)

-            if accept == 1:
-                accept_count += 1
-            else:
-                accept_count = 0
+        # Compute temporal difference
+        TD, s_next = temp_diff(model, qtable, s, accept, td_key)

-            TD, s_next = self.temp_diff(qtable, s, accept)
+        # Update qtable
+        qtable_new = qtable.at[s, accept].add(lr * TD)

-            # update qtable
-            qtable_new = qtable.copy()
-            qtable_new[s, accept] = qtable[s, accept] + lr*TD
+        # Calculate error
+        error = jnp.max(jnp.abs(qtable_new - qtable))

-            if np.max(np.abs(qtable_new-qtable))<=δ:
-                break
+        return qtable_new, s_next, accept_count, t + 1, key, error

-            if accept_count == T:
-                break
+    def cond_fun(state):
+        qtable, s, accept_count, t, key, error = state
+        # `error` is carried in the loop state and starts above δ,
+        # so the first iteration always runs

-            s, qtable = s_next, qtable_new
+        continue_condition = (error > δ) & (accept_count < T) & (t < max_times)
+        return continue_condition

-        return qtable_new
+    # Initial loop state: (qtable, state, accept_count, iteration, key, error)
+    init_state = (qtable, s0, 0, 0, subkey2, δ + 1.0)
+    final_qtable, *_ = jax.lax.while_loop(
+        cond_fun, body_fun, init_state
+    )

-@jit
-def run_epochs(N, qlmc, qtable):
+    return final_qtable
+
+
+def run_epochs(N, qlmc, qtable, key):
    """
    Run epochs N times with qtable from the last iteration each time.
    """
-
    for n in range(N):
-        if n%(N/10)==0:
+        if n % (N // 10) == 0:
            print(f"Progress: EPOCHs = {n}")
-        new_qtable = qlmc.run_one_epoch(qtable)
-        qtable = new_qtable
+
+        # Split key for this epoch
+        key, subkey = jax.random.split(key)
+        qtable = run_one_epoch(qlmc, qtable, subkey)

    return qtable

+
def valfunc_from_qtable(qtable):
-    return np.max(qtable, axis=1)
+    return jnp.max(qtable, axis=1)
+

def compute_error(valfunc, valfunc_VFI):
-    return np.mean(np.abs(valfunc - valfunc_VFI))
+    return jnp.mean(jnp.abs(valfunc - valfunc_VFI))
```

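+As an aside on the translation above: Python `if`/`else` branching on traced
+values is not allowed inside `jax.jit`-compiled code, which is why the original
+branches are re-expressed with `jax.lax.cond` and `jnp.where`. The cell below is
+a minimal sketch of the two patterns, using a toy function `branch_demo` that is
+purely illustrative and unrelated to the model.
+
+```{code-cell} ipython3
+import jax
+import jax.numpy as jnp
+
+@jax.jit
+def branch_demo(x):
+    # `branch_demo` is a toy illustration, not part of the McCall model code
+    # jnp.where computes both branches and selects the result elementwise
+    selected = jnp.where(x > 0, x**2, -x)
+
+    # jax.lax.cond traces both branches but executes only the chosen one
+    def positive_case():
+        return x**2
+
+    def negative_case():
+        return -x
+
+    conditional = jax.lax.cond(x > 0, positive_case, negative_case)
+    return selected, conditional
+
+branch_demo(jnp.array(3.0))
+```
+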
```{code-cell} ipython3
-# create an instance of Qlearning_McCall
-qlmc = Qlearning_McCall()
+# create an instance of QlearningMcCall
+qlmc = create_qlearning_mccall()

# run
-qtable0 = np.zeros((len(w_default), 2))
-qtable = run_epochs(20000, qlmc, qtable0)
+qtable0 = jnp.zeros((len(w_default), 2))
+key, subkey = jax.random.split(key)
+qtable = run_epochs(20000, qlmc, qtable0, subkey)
```
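+As a quick check, and using only the helpers defined above, we can read off the
+value function implied by the learned Q-table (its row-wise maximum) together
+with the greedy accept/reject decision in each wage state.
+
+```{code-cell} ipython3
+# Sketch: inspect the learned Q-table
+valfunc_qlr = valfunc_from_qtable(qtable)   # implied value function
+print(valfunc_qlr)
+
+# greedy action per wage state: 1 = accept the offer, 0 = reject
+print(jnp.argmax(qtable, axis=1))
+```
+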

```{code-cell} ipython3
@@ -641,7 +672,7 @@ n, a, b = 30, 200, 100 # default parameters
q_new = BetaBinomial(n, a, b).pdf()     # default choice of q

w_min, w_max = 10, 60
-w_new = np.linspace(w_min, w_max, n+1)
+w_new = jnp.linspace(w_min, w_max, n+1)


# plot distribution of wage offer
@@ -651,47 +682,47 @@ ax.set_xlabel('wages')
ax.set_ylabel('probabilities')

plt.show()
-
-# VFI
-mcm = McCallModel(w=w_new, q=q_new)
-valfunc_VFI, flag = mcm.VFI()
```

```{code-cell} ipython3
-mcm = McCallModel(w=w_new, q=q_new)
-valfunc_VFI, flag = mcm.VFI()
+# VFI
+mcm = create_mccall_model(w=w_new, q=q_new)
+valfunc_VFI, flag = VFI(mcm)
valfunc_VFI
```

```{code-cell} ipython3
-def plot_epochs(epochs_to_plot, quit_allowed=1):
-    "Plot value function implied by outcomes of an increasing number of epochs."
-    qlmc_new = Qlearning_McCall(w=w_new, q=q_new, quit_allowed=quit_allowed)
-    qtable = np.zeros((len(w_new),2))
-    epochs_to_plot = np.asarray(epochs_to_plot)
-    # plot
-    fig, ax = plt.subplots(figsize=(10,6))
-    ax.plot(w_new, valfunc_VFI, '-o', label='VFI')
-
-    max_epochs = np.max(epochs_to_plot)
-    # iterate on epoch numbers
-    for n in range(max_epochs + 1):
-        if n%(max_epochs/10)==0:
-            print(f"Progress: EPOCHs = {n}")
-        if n in epochs_to_plot:
-            valfunc_qlr = valfunc_from_qtable(qtable)
-            error = compute_error(valfunc_qlr, valfunc_VFI)
-
-            ax.plot(w_new, valfunc_qlr, '-o', label=f'QL:epochs={n}, mean error={error}')
-
-
-        new_qtable = qlmc_new.run_one_epoch(qtable)
-        qtable = new_qtable
-
-    ax.set_xlabel('wages')
-    ax.set_ylabel('optimal value')
-    ax.legend(loc='lower right')
-    plt.show()
+def plot_epochs(epochs_to_plot, quit_allowed=1, key=None):
+    "Plot value function implied by outcomes of an increasing number of epochs."
+    if key is None:
+        key = jax.random.PRNGKey(42)  # Default key if none provided
+
+    qlmc_new = create_qlearning_mccall(w=w_new, q=q_new, quit_allowed=quit_allowed)
+    qtable = jnp.zeros((len(w_new), 2))
+    epochs_to_plot = jnp.asarray(epochs_to_plot)
+    # plot
+    fig, ax = plt.subplots(figsize=(10, 6))
+    ax.plot(w_new, valfunc_VFI, '-o', label='VFI')
+
+    max_epochs = int(jnp.max(epochs_to_plot))  # Convert to Python int
+    # iterate on epoch numbers
+    for n in range(max_epochs + 1):
+        if n % (max_epochs / 10) == 0:
+            print(f"Progress: EPOCHs = {n}")
+        if n in epochs_to_plot:
+            valfunc_qlr = valfunc_from_qtable(qtable)
+            error = compute_error(valfunc_qlr, valfunc_VFI)
+
+            ax.plot(w_new, valfunc_qlr, '-o', label=f'QL:epochs={n}, mean error={error}')
+
+        # Split key for this epoch
+        key, subkey = jax.random.split(key)
+        qtable = run_one_epoch(qlmc_new, qtable, subkey)
+
+    ax.set_xlabel('wages')
+    ax.set_ylabel('optimal value')
+    ax.legend(loc='lower right')
+    plt.show()
```

```{code-cell} ipython3