Skip to content

Commit 6f63d18

Browse files
committed
update to jax code
1 parent 0184d5e commit 6f63d18

File tree

1 file changed

+198
-96
lines changed

1 file changed

+198
-96
lines changed

lectures/inventory_dynamics.md

Lines changed: 198 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ jupytext:
33
text_representation:
44
extension: .md
55
format_name: myst
6+
format_version: 0.13
7+
jupytext_version: 1.16.7
68
kernelspec:
7-
display_name: Python 3
9+
display_name: Python 3 (ipykernel)
810
language: python
911
name: python3
1012
---
@@ -46,11 +48,21 @@ While our Markov environment and many of the concepts we consider are related to
4648

4749
Let's start with some imports
4850

51+
```{code-cell} ipython3
52+
import jupyter_black
53+
jupyter_black.load(line_length=79)
54+
```
55+
4956
```{code-cell} ipython3
5057
import matplotlib.pyplot as plt
5158
import numpy as np
52-
from numba import jit, float64, prange
53-
from numba.experimental import jitclass
59+
from typing import NamedTuple
60+
import jax
61+
import jax.numpy as jnp
62+
from jax import random
63+
64+
# from numba import jit, float64, prange
65+
# from numba.experimental import jitclass
5466
```
5567

5668
## Sample Paths
@@ -83,64 +95,91 @@ and standard normal.
8395

8496
Here's a class that stores parameters and generates time paths for inventory.
8597

86-
```{code-cell} python3
87-
firm_data = [
88-
('s', float64), # restock trigger level
89-
('S', float64), # capacity
90-
('mu', float64), # shock location parameter
91-
('sigma', float64) # shock scale parameter
92-
]
93-
94-
95-
@jitclass(firm_data)
96-
class Firm:
97-
98-
def __init__(self, s=10, S=100, mu=1.0, sigma=0.5):
99-
100-
self.s, self.S, self.mu, self.sigma = s, S, mu, sigma
101-
102-
def update(self, x):
103-
"Update the state from t to t+1 given current state x."
104-
105-
Z = np.random.randn()
106-
D = np.exp(self.mu + self.sigma * Z)
107-
if x <= self.s:
108-
return max(self.S - D, 0)
109-
else:
110-
return max(x - D, 0)
98+
```{code-cell} ipython3
99+
class Firm(NamedTuple):
100+
s: int # restock trigger level
101+
S: int # capacity
102+
μ: float # shock location parameter
103+
σ: float # shock scale parameter
104+
```
111105

112-
def sim_inventory_path(self, x_init, sim_length):
106+
```{code-cell} ipython3
107+
@jax.jit
108+
def sim_inventory_path(firm, x_init, random_keys):
109+
"""
110+
Simulate inventory path.
111+
112+
Args:
113+
firm: Firm object
114+
x_init: Initial inventory level
115+
sim_length: Length of simulation
116+
key: JAX random key
117+
118+
Returns:
119+
Array of inventory levels over time
120+
"""
121+
122+
def update_step(carry, key_t):
123+
"""
124+
Single update step
125+
126+
Args:
127+
carry: Current inventory level (x)
128+
key_t: Random key for this time step
129+
130+
Returns:
131+
(new_x, new_x): Updated inventory level (returned twice for scan)
132+
"""
133+
x = carry
134+
135+
# Generate random demand
136+
Z = random.normal(key_t)
137+
D = jnp.exp(firm.μ + firm.σ * Z)
138+
139+
# Update inventory based on (s, S) policy
140+
new_x = jax.lax.cond(
141+
x <= firm.s,
142+
lambda: jnp.maximum(
143+
firm.S - D, 0.0
144+
), # Reorder to S, then subtract demand
145+
lambda: jnp.maximum(x - D, 0.0), # Just subtract demand
146+
)
147+
148+
return new_x, new_x
149+
150+
# Use scan to iterate through time steps
151+
final_x, X_path = jax.lax.scan(update_step, x_init, random_keys)
152+
153+
# Prepend initial value
154+
X = jnp.concatenate([jnp.array([x_init]), X_path])
155+
156+
return X
157+
```
113158

114-
X = np.empty(sim_length)
115-
X[0] = x_init
159+
```{code-cell} ipython3
160+
firm = Firm(s=10, S=100, μ=1.0, σ=0.5)
161+
```
116162

117-
for t in range(sim_length-1):
118-
X[t+1] = self.update(X[t])
119-
return X
163+
```{code-cell} ipython3
164+
sim_length = 100
165+
x_init = 50
166+
keys = random.split(random.PRNGKey(21), sim_length - 1)
167+
X = sim_inventory_path(firm, x_init, keys)
120168
```
121169

122170
Let's run a first simulation, of a single path:
123171

124172
```{code-cell} ipython3
125-
firm = Firm()
126-
127173
s, S = firm.s, firm.S
128-
sim_length = 100
129-
x_init = 50
130-
131-
X = firm.sim_inventory_path(x_init, sim_length)
132174
133175
fig, ax = plt.subplots()
134-
bbox = (0., 1.02, 1., .102)
135-
legend_args = {'ncol': 3,
136-
'bbox_to_anchor': bbox,
137-
'loc': 3,
138-
'mode': 'expand'}
176+
bbox = (0.0, 1.02, 1.0, 0.102)
177+
legend_args = {"ncol": 3, "bbox_to_anchor": bbox, "loc": 3, "mode": "expand"}
139178
140179
ax.plot(X, label="inventory")
141-
ax.plot(np.full(sim_length, s), 'k--', label="$s$")
142-
ax.plot(np.full(sim_length, S), 'k-', label="$S$")
143-
ax.set_ylim(0, S+10)
180+
ax.plot(jnp.full(sim_length, s), "k--", label="$s$")
181+
ax.plot(jnp.full(sim_length, S), "k-", label="$S$")
182+
ax.set_ylim(0, S + 10)
144183
ax.set_xlabel("time")
145184
ax.legend(**legend_args)
146185
@@ -151,17 +190,18 @@ Now let's simulate multiple paths in order to build a more complete picture of
151190
the probabilities of different outcomes:
152191

153192
```{code-cell} ipython3
154-
sim_length=200
193+
sim_length = 200
155194
fig, ax = plt.subplots()
156195
157-
ax.plot(np.full(sim_length, s), 'k--', label="$s$")
158-
ax.plot(np.full(sim_length, S), 'k-', label="$S$")
159-
ax.set_ylim(0, S+10)
196+
ax.plot(jnp.full(sim_length, s), "k--", label="$s$")
197+
ax.plot(jnp.full(sim_length, S), "k-", label="$S$")
198+
ax.set_ylim(0, S + 10)
160199
ax.legend(**legend_args)
161200
162201
for i in range(400):
163-
X = firm.sim_inventory_path(x_init, sim_length)
164-
ax.plot(X, 'b', alpha=0.2, lw=0.5)
202+
keys = random.split(random.PRNGKey(i), sim_length - 1)
203+
X = sim_inventory_path(firm, x_init, keys)
204+
ax.plot(X, "b", alpha=0.2, lw=0.5)
165205
166206
plt.show()
167207
```
@@ -192,27 +232,30 @@ for ax in axes:
192232
ax = axes[0]
193233
194234
ax.set_ylim(ymin, ymax)
195-
ax.set_ylabel('$X_t$', fontsize=16)
235+
ax.set_ylabel("$X_t$", fontsize=16)
196236
ax.vlines((T,), -1.5, 1.5)
197237
198238
ax.set_xticks((T,))
199-
ax.set_xticklabels((r'$T$',))
239+
ax.set_xticklabels((r"$T$",))
200240
201-
sample = np.empty(M)
241+
sample = []
202242
for m in range(M):
203-
X = firm.sim_inventory_path(x_init, 2 * T)
204-
ax.plot(X, 'b-', lw=1, alpha=0.5)
205-
ax.plot((T,), (X[T+1],), 'ko', alpha=0.5)
206-
sample[m] = X[T+1]
243+
keys = random.split(random.PRNGKey(m), sim_length - 1)
244+
X = sim_inventory_path(firm, x_init, keys)
245+
ax.plot(X, "b-", lw=1, alpha=0.5)
246+
ax.plot((T,), (X[T + 1],), "ko", alpha=0.5)
247+
sample.append(X[T + 1])
207248
208249
axes[1].set_ylim(ymin, ymax)
209250
210-
axes[1].hist(sample,
211-
bins=16,
212-
density=True,
213-
orientation='horizontal',
214-
histtype='bar',
215-
alpha=0.5)
251+
axes[1].hist(
252+
sample,
253+
bins=16,
254+
density=True,
255+
orientation="horizontal",
256+
histtype="bar",
257+
alpha=0.5,
258+
)
216259
217260
plt.show()
218261
```
@@ -225,16 +268,13 @@ M = 50_000
225268
226269
fig, ax = plt.subplots()
227270
228-
sample = np.empty(M)
271+
sample = []
229272
for m in range(M):
230-
X = firm.sim_inventory_path(x_init, T+1)
231-
sample[m] = X[T]
273+
keys = random.split(random.PRNGKey(m), T)
274+
X = sim_inventory_path(firm, x_init, keys)
275+
sample.append(X[T])
232276
233-
ax.hist(sample,
234-
bins=36,
235-
density=True,
236-
histtype='bar',
237-
alpha=0.75)
277+
ax.hist(sample, bins=36, density=True, histtype="bar", alpha=0.75)
238278
239279
plt.show()
240280
```
@@ -255,14 +295,15 @@ We will use a kernel density estimator from [scikit-learn](https://scikit-learn.
255295
```{code-cell} ipython3
256296
from sklearn.neighbors import KernelDensity
257297
258-
def plot_kde(sample, ax, label=''):
259298
299+
def plot_kde(sample, ax, label=""):
300+
sample = jnp.array(sample)
260301
xmin, xmax = 0.9 * min(sample), 1.1 * max(sample)
261-
xgrid = np.linspace(xmin, xmax, 200)
262-
kde = KernelDensity(kernel='gaussian').fit(sample[:, None])
302+
xgrid = jnp.linspace(xmin, xmax, 200)
303+
kde = KernelDensity(kernel="gaussian").fit(sample[:, None])
263304
log_dens = kde.score_samples(xgrid[:, None])
264305
265-
ax.plot(xgrid, np.exp(log_dens), label=label)
306+
ax.plot(xgrid, jnp.exp(log_dens), label=label)
266307
```
267308

268309
```{code-cell} ipython3
@@ -313,25 +354,85 @@ code efficiently.
313354
This meant writing a specialized function rather than using the class above.
314355

315356
```{code-cell} ipython3
316-
s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
317-
318-
@jit(parallel=True)
319-
def shift_firms_forward(current_inventory_levels, num_periods):
357+
# s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
358+
359+
# @jit(parallel=True)
360+
# def shift_firms_forward(current_inventory_levels, num_periods):
361+
362+
# num_firms = len(current_inventory_levels)
363+
# new_inventory_levels = np.empty(num_firms)
364+
365+
# for f in prange(num_firms):
366+
# x = current_inventory_levels[f]
367+
# for t in range(num_periods):
368+
# Z = np.random.randn()
369+
# D = np.exp(mu + sigma * Z)
370+
# if x <= s:
371+
# x = max(S - D, 0)
372+
# else:
373+
# x = max(x - D, 0)
374+
# new_inventory_levels[f] = x
375+
376+
# return new_inventory_levels
377+
```
320378

379+
```{code-cell} ipython3
380+
def shift_firms_forward(firm: Firm, current_inventory_levels, num_periods: int, key: random.PRNGKey):
381+
"""
382+
Shift multiple firms forward by num_periods using JAX vectorization.
383+
384+
Args:
385+
firm: Firm dataclass with parameters s, S, mu, sigma
386+
current_inventory_levels: Array of current inventory levels for each firm
387+
num_periods: Number of periods to simulate forward
388+
key: JAX random key
389+
390+
Returns:
391+
Array of new inventory levels after num_periods
392+
"""
393+
394+
def simulate_single_firm(x_init, firm_key):
395+
"""
396+
Simulate a single firm forward by num_periods.
397+
398+
Args:
399+
x_init: Initial inventory level for this firm
400+
firm_key: Random key for this firm
401+
402+
Returns:
403+
Final inventory level after num_periods
404+
"""
405+
406+
def update_step(x, period_key):
407+
"""Single period update step."""
408+
Z = random.normal(period_key)
409+
D = jnp.exp(firm.mu + firm.sigma * Z)
410+
411+
new_x = jax.lax.cond(
412+
x <= firm.s,
413+
lambda: jnp.maximum(firm.S - D, 0.0),
414+
lambda: jnp.maximum(x - D, 0.0)
415+
)
416+
return new_x, None # Return None for scan accumulator (we don't need it)
417+
418+
# Generate keys for each period
419+
period_keys = random.split(firm_key, num_periods)
420+
421+
# Simulate forward num_periods
422+
final_x, _ = jax.lax.scan(update_step, x_init, period_keys)
423+
424+
return final_x
425+
426+
# Generate independent random keys for each firm
321427
num_firms = len(current_inventory_levels)
322-
new_inventory_levels = np.empty(num_firms)
323-
324-
for f in prange(num_firms):
325-
x = current_inventory_levels[f]
326-
for t in range(num_periods):
327-
Z = np.random.randn()
328-
D = np.exp(mu + sigma * Z)
329-
if x <= s:
330-
x = max(S - D, 0)
331-
else:
332-
x = max(x - D, 0)
333-
new_inventory_levels[f] = x
334-
428+
firm_keys = random.split(key, num_firms)
429+
430+
# Vectorize over all firms using vmap
431+
vectorized_simulate = jax.vmap(simulate_single_firm, in_axes=(0, 0))
432+
433+
# Run simulation for all firms in parallel
434+
new_inventory_levels = vectorized_simulate(current_inventory_levels, firm_keys)
435+
335436
return new_inventory_levels
336437
```
337438

@@ -395,6 +496,7 @@ specialized function rather than using the class above.
395496
We will also use parallelization across firms.
396497

397498
```{code-cell} ipython3
499+
# TODO: Update this to JAX
398500
@jit(parallel=True)
399501
def compute_freq(sim_length=50, x_init=70, num_firms=1_000_000):
400502

0 commit comments

Comments
 (0)