Skip to content

Commit c69d3c8

Browse files
committed
fix suggestions
1 parent 2b7c4fe commit c69d3c8

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

lectures/inventory_dynamics.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,12 @@ While our Markov environment and many of the concepts we consider are related to
4949
Let's start with some imports
5050

5151
```{code-cell} ipython3
52-
import matplotlib.pyplot as plt
5352
from typing import NamedTuple
5453
import jax
5554
import jax.numpy as jnp
5655
from jax import random
56+
import matplotlib.pyplot as plt
57+
from sklearn.neighbors import KernelDensity
5758
```
5859

5960
## Sample paths
@@ -103,7 +104,7 @@ def sim_inventory_path(firm, x_init, random_keys):
103104
Args:
104105
firm: Firm object
105106
x_init: Initial inventory level
106-
random_keys: Array of JAX random keys
107+
random_keys: Array of JAX random keys of length sim_length - 1.
107108
108109
Returns:
109110
Array of inventory levels over time
@@ -143,6 +144,7 @@ firm = Firm(s=10, S=100, μ=1.0, σ=0.5)
143144
```{code-cell} ipython3
144145
sim_length = 100
145146
x_init = 50
147+
# Generate `sim_length-1` keys as `x_init` will be first array element
146148
keys = random.split(random.PRNGKey(21), sim_length - 1)
147149
X = sim_inventory_path(firm, x_init, keys)
148150
```
@@ -179,6 +181,7 @@ ax.set_ylim(0, S + 10)
179181
ax.legend(**legend_args)
180182
181183
for i in range(400):
184+
# Generate `sim_length-1` keys as `x_init` will be first array element
182185
keys = random.split(random.PRNGKey(i), sim_length - 1)
183186
X = sim_inventory_path(firm, x_init, keys)
184187
ax.plot(X, "b", alpha=0.2, lw=0.5)
@@ -254,6 +257,8 @@ for m in range(M):
254257
X = sim_inventory_path(firm, x_init, keys)
255258
sample.append(X[T])
256259
260+
# Convert to JAX array
261+
sample = jnp.array(sample)
257262
ax.hist(sample, bins=36, density=True, histtype="bar", alpha=0.75)
258263
259264
plt.show()
@@ -273,11 +278,7 @@ They are preferable to histograms when the distribution being estimated is likel
273278
We will use a kernel density estimator from [scikit-learn](https://scikit-learn.org/stable/)
274279

275280
```{code-cell} ipython3
276-
from sklearn.neighbors import KernelDensity
277-
278-
279281
def plot_kde(sample, ax, label=""):
280-
sample = jnp.array(sample)
281282
xmin, xmax = 0.9 * min(sample), 1.1 * max(sample)
282283
xgrid = jnp.linspace(xmin, xmax, 200)
283284
kde = KernelDensity(kernel="gaussian").fit(sample[:, None])
@@ -399,7 +400,7 @@ X = jnp.full(num_firms, x_init)
399400
400401
current_date = 0
401402
for d in first_diffs:
402-
X = shift_firms_forward(firm, X, d, random.PRNGKey(d))
403+
X = shift_firms_forward(firm, X, d, random.PRNGKey(current_date + 1))
403404
current_date += d
404405
plot_kde(X, ax, label=f"t = {current_date}")
405406

0 commit comments

Comments
 (0)