Skip to content

Commit 89d1b09

Browse files
committed
put default values of variables into the class
1 parent f53c4ef commit 89d1b09

File tree

1 file changed

+18
-39
lines changed

1 file changed

+18
-39
lines changed

lectures/mccall_q.md

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -495,37 +495,16 @@ The first column of the Q-table represents the value associated with rejecting t
495495
We use JAX compilation to accelerate computations.
496496
497497
```{code-cell} ipython3
498-
class QlearningMcCall(NamedTuple):
499-
c: float # unemployment compensation
500-
β: float # discount factor
501-
w: jnp.ndarray # array of wage values, w[i] = wage at state i
502-
q: jnp.ndarray # array of probabilities
503-
ε: float # for ε greedy algorithm
504-
δ: float # Q-table threshold
505-
lr: float # the learning rate α
506-
T: int # maximum periods of accepting
507-
quit_allowed: int # whether quit is allowed after accepting the wage offer
508-
509-
510-
def create_qlearning_mccall(c=25,
511-
β=0.99,
512-
w=w_default,
513-
q=q_default,
514-
ε=0.1,
515-
δ=1e-5,
516-
lr=0.5,
517-
T=10000,
518-
quit_allowed=0):
519-
return QlearningMcCall(c=c,
520-
β=β,
521-
w=w,
522-
q=q,
523-
ε=ε,
524-
δ=δ,
525-
lr=lr,
526-
T=T,
527-
quit_allowed=quit_allowed)
528-
498+
class QLearningMcCall(NamedTuple):
499+
c: float = 25 # unemployment compensation
500+
β: float = 0.99 # discount factor
501+
w: jnp.ndarray = w_default # array of wage values, w[i] = wage at state i
502+
q: jnp.ndarray = q_default # array of probabilities
503+
ε: float = 0.1 # for ε greedy algorithm
504+
δ: float = 1e-5 # Q-table threshold
505+
lr: float = 0.5 # the learning rate α
506+
T: int = 10000 # maximum periods of accepting
507+
quit_allowed: int = 0 # whether quit is allowed after accepting the wage offer
529508
530509
@jax.jit
531510
def draw_offer_index(model, key):
@@ -546,15 +525,15 @@ def temp_diff(model, qtable, state, accept, key):
546525
547526
def reject_case():
548527
state_next = draw_offer_index(model, key)
549-
TD = c + β * jnp.max(qtable[state_next, :]) - qtable[state, accept]
550-
return TD, state_next
528+
td = c + β * jnp.max(qtable[state_next, :]) - qtable[state, accept]
529+
return td, state_next
551530
552531
def accept_case():
553532
state_next = state
554-
TD = jnp.where(model.quit_allowed == 0,
533+
td = jnp.where(model.quit_allowed == 0,
555534
w[state_next] + β * jnp.max(qtable[state_next, :]) - qtable[state, accept],
556535
w[state_next] + β * qtable[state_next, 1] - qtable[state, accept])
557-
return TD, state_next
536+
return td, state_next
558537
559538
return jax.lax.cond(accept == 0, reject_case, accept_case)
560539
@@ -585,10 +564,10 @@ def run_one_epoch(model, qtable, key, max_times=20000):
585564
accept_count = jnp.where(accept == 1, accept_count + 1, 0)
586565
587566
# Compute temporal difference
588-
TD, s_next = temp_diff(model, qtable, s, accept, td_key)
567+
td, s_next = temp_diff(model, qtable, s, accept, td_key)
589568
590569
# Update qtable
591-
qtable_new = qtable.at[s, accept].add(lr * TD)
570+
qtable_new = qtable.at[s, accept].add(lr * td)
592571
593572
# Calculate error
594573
error = jnp.max(jnp.abs(qtable_new - qtable))
@@ -637,7 +616,7 @@ def compute_error(valfunc, valfunc_VFI):
637616
```
638617
639618
```{code-cell} ipython3
640-
# create an instance of Qlearning_McCall
619+
# create an instance of QLearningMcCall
641620
qlmc = QLearningMcCall()
642621
643622
# run
@@ -689,7 +668,7 @@ plt.show()
689668
```
690669
691670
```{code-cell} ipython3
692-
# VFI
671+
# vfi
693672
mcm = create_mccall_model(w=w_new, q=q_new)
694673
valfunc_VFI, converged = vfi(mcm)
695674
valfunc_VFI

0 commit comments

Comments
 (0)