@@ -495,37 +495,16 @@ The first column of the Q-table represents the value associated with rejecting t
 We use JAX compilation to accelerate computations.
 
 ```{code-cell} ipython3
-class QlearningMcCall(NamedTuple):
-    c: float              # unemployment compensation
-    β: float              # discount factor
-    w: jnp.ndarray        # array of wage values, w[i] = wage at state i
-    q: jnp.ndarray        # array of probabilities
-    ε: float              # for ε greedy algorithm
-    δ: float              # Q-table threshold
-    lr: float             # the learning rate α
-    T: int                # maximum periods of accepting
-    quit_allowed: int     # whether quit is allowed after accepting the wage offer
-
-
-def create_qlearning_mccall(c=25,
-                            β=0.99,
-                            w=w_default,
-                            q=q_default,
-                            ε=0.1,
-                            δ=1e-5,
-                            lr=0.5,
-                            T=10000,
-                            quit_allowed=0):
-    return QlearningMcCall(c=c,
-                           β=β,
-                           w=w,
-                           q=q,
-                           ε=ε,
-                           δ=δ,
-                           lr=lr,
-                           T=T,
-                           quit_allowed=quit_allowed)
-
+class QLearningMcCall(NamedTuple):
+    c: float = 25                  # unemployment compensation
+    β: float = 0.99                # discount factor
+    w: jnp.ndarray = w_default     # array of wage values, w[i] = wage at state i
+    q: jnp.ndarray = q_default     # array of probabilities
+    ε: float = 0.1                 # for ε greedy algorithm
+    δ: float = 1e-5                # Q-table threshold
+    lr: float = 0.5                # the learning rate α
+    T: int = 10000                 # maximum periods of accepting
+    quit_allowed: int = 0          # whether quit is allowed after accepting the wage offer
 
 @jax.jit
 def draw_offer_index(model, key):
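Since the new `QLearningMcCall` carries defaults on its fields, the deleted `create_qlearning_mccall` factory becomes redundant: instances can be built directly, and `._replace` gives cheap modified copies. A minimal usage sketch, assuming `w_default` and `q_default` are defined earlier in the lecture (they appear as field defaults above):

```python
# Construct with all defaults, or override selected fields at the call site.
qlmc = QLearningMcCall()
qlmc_high_c = QLearningMcCall(c=40.0)

# NamedTuples are immutable; _replace returns a modified copy.
qlmc_patient = qlmc._replace(β=0.999)
```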
@@ -546,15 +525,15 @@ def temp_diff(model, qtable, state, accept, key):
 
     def reject_case():
         state_next = draw_offer_index(model, key)
-        TD = c + β * jnp.max(qtable[state_next, :]) - qtable[state, accept]
-        return TD, state_next
+        td = c + β * jnp.max(qtable[state_next, :]) - qtable[state, accept]
+        return td, state_next
 
     def accept_case():
         state_next = state
-        TD = jnp.where(model.quit_allowed == 0,
+        td = jnp.where(model.quit_allowed == 0,
                        w[state_next] + β * jnp.max(qtable[state_next, :]) - qtable[state, accept],
                        w[state_next] + β * qtable[state_next, 1] - qtable[state, accept])
-        return TD, state_next
+        return td, state_next
 
     return jax.lax.cond(accept == 0, reject_case, accept_case)
 
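For reference, the temporal differences computed by `reject_case` and `accept_case` above correspond to (with current state $s$, actions $a \in \{0, 1\}$ for reject/accept, and a fresh wage draw $s' \sim q$ in the reject branch):

$$
\begin{aligned}
\text{reject:} \quad \mathrm{td} &= c + \beta \max_{a'} Q(s', a') - Q(s, a), \\
\text{accept:} \quad \mathrm{td} &= w_s + \beta \max_{a'} Q(s, a') - Q(s, a) && \text{if } \texttt{quit\_allowed} = 0, \\
\mathrm{td} &= w_s + \beta \, Q(s, 1) - Q(s, a) && \text{otherwise},
\end{aligned}
$$

where the accepted wage is retained, so the next state equals $s$ in the accept branch.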
@@ -585,10 +564,10 @@ def run_one_epoch(model, qtable, key, max_times=20000):
         accept_count = jnp.where(accept == 1, accept_count + 1, 0)
 
         # Compute temporal difference
-        TD, s_next = temp_diff(model, qtable, s, accept, td_key)
+        td, s_next = temp_diff(model, qtable, s, accept, td_key)
 
         # Update qtable
-        qtable_new = qtable.at[s, accept].add(lr * TD)
+        qtable_new = qtable.at[s, accept].add(lr * td)
 
         # Calculate error
         error = jnp.max(jnp.abs(qtable_new - qtable))
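The update line implements the standard Q-learning step $Q(s, a) \leftarrow Q(s, a) + \alpha \, \mathrm{td}$ using JAX's functional indexing: `jnp` arrays are immutable, so `qtable.at[s, accept].add(lr * td)` returns a new table rather than mutating in place. A small self-contained sketch of the idiom, with hypothetical values not taken from the lecture:

```python
import jax.numpy as jnp

qtable = jnp.zeros((3, 2))   # 3 states, 2 actions
lr, td = 0.5, 2.0            # hypothetical learning rate and TD error

# .at[...].add(...) returns a fresh array; qtable itself is unchanged
qtable_new = qtable.at[0, 1].add(lr * td)
print(qtable_new[0, 1])      # 1.0
print(qtable[0, 1])          # 0.0
```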
@@ -637,7 +616,7 @@ def compute_error(valfunc, valfunc_VFI):
 ```
 
 ```{code-cell} ipython3
-# create an instance of Qlearning_McCall
-qlmc = create_qlearning_mccall()
+# create an instance of QLearningMcCall
+qlmc = QLearningMcCall()
 
 # run
@@ -689,7 +668,7 @@ plt.show()
 ```
 
 ```{code-cell} ipython3
-# VFI
+# vfi
 mcm = create_mccall_model(w=w_new, q=q_new)
 valfunc_VFI, converged = vfi(mcm)
 valfunc_VFI