@@ -49,11 +49,12 @@ While our Markov environment and many of the concepts we consider are related to
4949Let's start with some imports
5050
5151``` {code-cell} ipython3
52- import matplotlib.pyplot as plt
5352from typing import NamedTuple
5453import jax
5554import jax.numpy as jnp
5655from jax import random
56+ import matplotlib.pyplot as plt
57+ from sklearn.neighbors import KernelDensity
5758```
5859
5960## Sample paths
@@ -103,7 +104,7 @@ def sim_inventory_path(firm, x_init, random_keys):
103104 Args:
104105 firm: Firm object
105106 x_init: Initial inventory level
106- random_keys: Array of JAX random keys
107+ random_keys: Array of JAX random keys of length sim_length - 1.
107108
108109 Returns:
109110 Array of inventory levels over time
@@ -143,6 +144,7 @@ firm = Firm(s=10, S=100, μ=1.0, σ=0.5)
143144``` {code-cell} ipython3
144145sim_length = 100
145146x_init = 50
147+ # Generate `sim_length-1` keys as `x_init` will be first array element
146148keys = random.split(random.PRNGKey(21), sim_length - 1)
147149X = sim_inventory_path(firm, x_init, keys)
148150```
@@ -179,6 +181,7 @@ ax.set_ylim(0, S + 10)
179181ax.legend(**legend_args)
180182
181183for i in range(400):
184+ # Generate `sim_length-1` keys as `x_init` will be first array element
182185 keys = random.split(random.PRNGKey(i), sim_length - 1)
183186 X = sim_inventory_path(firm, x_init, keys)
184187 ax.plot(X, "b", alpha=0.2, lw=0.5)
@@ -254,6 +257,8 @@ for m in range(M):
254257 X = sim_inventory_path(firm, x_init, keys)
255258 sample.append(X[T])
256259
260+ # Convert to JAX array
261+ sample = jnp.array(sample)
257262ax.hist(sample, bins=36, density=True, histtype="bar", alpha=0.75)
258263
259264plt.show()
@@ -273,11 +278,7 @@ They are preferable to histograms when the distribution being estimated is likel
273278We will use a kernel density estimator from [ scikit-learn] ( https://scikit-learn.org/stable/ )
274279
275280``` {code-cell} ipython3
276- from sklearn.neighbors import KernelDensity
277-
278-
279281def plot_kde(sample, ax, label=""):
280- sample = jnp.array(sample)
281282 xmin, xmax = 0.9 * min(sample), 1.1 * max(sample)
282283 xgrid = jnp.linspace(xmin, xmax, 200)
283284 kde = KernelDensity(kernel="gaussian").fit(sample[:, None])
@@ -399,7 +400,7 @@ X = jnp.full(num_firms, x_init)
399400
400401current_date = 0
401402for d in first_diffs:
402- X = shift_firms_forward(firm, X, d, random.PRNGKey(d ))
403+ X = shift_firms_forward(firm, X, d, random.PRNGKey(current_date + 1 ))
403404 current_date += d
404405 plot_kde(X, ax, label=f"t = {current_date}")
405406
0 commit comments