Skip to content

Commit e384f99

Browse files
committed
update all code to jax
1 parent 06cd78a commit e384f99

File tree

1 file changed

+129
-137
lines changed

1 file changed

+129
-137
lines changed

lectures/inventory_dynamics.md

Lines changed: 129 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -36,33 +36,24 @@ follow so-called s-S inventory dynamics.
3636
Such firms
3737

3838
1. wait until inventory falls below some level $s$ and then
39-
1. order sufficient quantities to bring their inventory back up to capacity $S$.
39+
2. order sufficient quantities to bring their inventory back up to capacity $S$.
4040

4141
These kinds of policies are common in practice and also optimal in certain circumstances.
4242

4343
A review of early literature and some macroeconomic implications can be found in {cite}`caplin1985variability`.
4444

4545
Here our main aim is to learn more about simulation, time series and Markov dynamics.
4646

47-
While our Markov environment and many of the concepts we consider are related to those found in our {doc}`lecture on finite Markov chains <finite_markov>`, the state space is a continuum in the current application.
47+
While our Markov environment and many of the concepts we consider are related to those found in our lecture {doc}`<finite_markov>`, the state space is a continuum in the current application.
4848

4949
Let's start with some imports
5050

51-
```{code-cell} ipython3
52-
import jupyter_black
53-
jupyter_black.load(line_length=79)
54-
```
55-
5651
```{code-cell} ipython3
5752
import matplotlib.pyplot as plt
58-
import numpy as np
5953
from typing import NamedTuple
6054
import jax
6155
import jax.numpy as jnp
6256
from jax import random
63-
64-
# from numba import jit, float64, prange
65-
# from numba.experimental import jitclass
6657
```
6758

6859
## Sample Paths
@@ -97,8 +88,8 @@ Here's a class that stores parameters and generates time paths for inventory.
9788

9889
```{code-cell} ipython3
9990
class Firm(NamedTuple):
100-
s: int # restock trigger level
101-
S: int # capacity
91+
s: int # restock trigger level
92+
S: int # capacity
10293
μ: float # shock location parameter
10394
σ: float # shock scale parameter
10495
```
@@ -112,8 +103,7 @@ def sim_inventory_path(firm, x_init, random_keys):
112103
Args:
113104
firm: Firm object
114105
x_init: Initial inventory level
115-
sim_length: Length of simulation
116-
key: JAX random key
106+
random_keys: Array of JAX random keys
117107
118108
Returns:
119109
Array of inventory levels over time
@@ -122,21 +112,11 @@ def sim_inventory_path(firm, x_init, random_keys):
122112
def update_step(carry, key_t):
123113
"""
124114
Single update step
125-
126-
Args:
127-
carry: Current inventory level (x)
128-
key_t: Random key for this time step
129-
130-
Returns:
131-
(new_x, new_x): Updated inventory level (returned twice for scan)
132115
"""
133116
x = carry
134-
135-
# Generate random demand
136117
Z = random.normal(key_t)
137118
D = jnp.exp(firm.μ + firm.σ * Z)
138119
139-
# Update inventory based on (s, S) policy
140120
new_x = jax.lax.cond(
141121
x <= firm.s,
142122
lambda: jnp.maximum(
@@ -243,8 +223,8 @@ for m in range(M):
243223
keys = random.split(random.PRNGKey(m), sim_length - 1)
244224
X = sim_inventory_path(firm, x_init, keys)
245225
ax.plot(X, "b-", lw=1, alpha=0.5)
246-
ax.plot((T,), (X[T + 1],), "ko", alpha=0.5)
247-
sample.append(X[T + 1])
226+
ax.plot((T,), (X[T+1],), "ko", alpha=0.5)
227+
sample.append(X[T+1])
248228
249229
axes[1].set_ylim(ymin, ymax)
250230
@@ -349,90 +329,59 @@ Try different initial conditions to verify that, in the long run, the distributi
349329
Below is one possible solution:
350330

351331
The computations involve a lot of CPU cycles so we have tried to write the
352-
code efficiently.
353-
354-
This meant writing a specialized function rather than using the class above.
332+
code efficiently using `jax.jit` and `jax.vmap` to run on CPU/GPU.
355333

356334
```{code-cell} ipython3
357-
# s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
358-
359-
# @jit(parallel=True)
360-
# def shift_firms_forward(current_inventory_levels, num_periods):
361-
362-
# num_firms = len(current_inventory_levels)
363-
# new_inventory_levels = np.empty(num_firms)
364-
365-
# for f in prange(num_firms):
366-
# x = current_inventory_levels[f]
367-
# for t in range(num_periods):
368-
# Z = np.random.randn()
369-
# D = np.exp(mu + sigma * Z)
370-
# if x <= s:
371-
# x = max(S - D, 0)
372-
# else:
373-
# x = max(x - D, 0)
374-
# new_inventory_levels[f] = x
375-
376-
# return new_inventory_levels
335+
@jax.jit
336+
def simulate_single_firm(x_init, period_keys):
337+
"""
338+
Simulate a single firm forward by num_periods.
339+
340+
Args:
341+
x_init: Initial inventory level for this firm
342+
period_keys: Random key for this firm for each period
343+
"""
344+
345+
def update_step(x, period_key):
346+
Z = random.normal(period_key)
347+
D = jnp.exp(firm.μ + firm.σ * Z)
348+
349+
new_x = jax.lax.cond(
350+
x <= firm.s,
351+
lambda: jnp.maximum(firm.S - D, 0.0),
352+
lambda: jnp.maximum(x - D, 0.0),
353+
)
354+
return (
355+
new_x,
356+
None,
357+
) # Return None for scan accumulator (we don't need it)
358+
359+
# Simulate forward num_periods
360+
final_x, _ = jax.lax.scan(update_step, x_init, period_keys)
361+
362+
return final_x
363+
364+
365+
# Vectorize over all firms using vmap
366+
vectorized_simulate = jax.vmap(simulate_single_firm, in_axes=(0, 0))
377367
```
378368

379369
```{code-cell} ipython3
380-
def shift_firms_forward(firm: Firm, current_inventory_levels, num_periods: int, key: random.PRNGKey):
370+
def shift_firms_forward(firm, current_inventory_levels, num_periods, key):
381371
"""
382372
Shift multiple firms forward by num_periods using JAX vectorization.
383-
384-
Args:
385-
firm: Firm dataclass with parameters s, S, mu, sigma
386-
current_inventory_levels: Array of current inventory levels for each firm
387-
num_periods: Number of periods to simulate forward
388-
key: JAX random key
389-
390373
Returns:
391374
Array of new inventory levels after num_periods
392375
"""
393-
394-
def simulate_single_firm(x_init, firm_key):
395-
"""
396-
Simulate a single firm forward by num_periods.
397-
398-
Args:
399-
x_init: Initial inventory level for this firm
400-
firm_key: Random key for this firm
401-
402-
Returns:
403-
Final inventory level after num_periods
404-
"""
405-
406-
def update_step(x, period_key):
407-
"""Single period update step."""
408-
Z = random.normal(period_key)
409-
D = jnp.exp(firm.mu + firm.sigma * Z)
410-
411-
new_x = jax.lax.cond(
412-
x <= firm.s,
413-
lambda: jnp.maximum(firm.S - D, 0.0),
414-
lambda: jnp.maximum(x - D, 0.0)
415-
)
416-
return new_x, None # Return None for scan accumulator (we don't need it)
417-
418-
# Generate keys for each period
419-
period_keys = random.split(firm_key, num_periods)
420-
421-
# Simulate forward num_periods
422-
final_x, _ = jax.lax.scan(update_step, x_init, period_keys)
423-
424-
return final_x
425-
376+
426377
# Generate independent random keys for each firm
427378
num_firms = len(current_inventory_levels)
428-
firm_keys = random.split(key, num_firms)
429-
430-
# Vectorize over all firms using vmap
431-
vectorized_simulate = jax.vmap(simulate_single_firm, in_axes=(0, 0))
432-
379+
firm_keys = random.split(key, (num_firms, num_periods))
433380
# Run simulation for all firms in parallel
434-
new_inventory_levels = vectorized_simulate(current_inventory_levels, firm_keys)
435-
381+
new_inventory_levels = vectorized_simulate(
382+
current_inventory_levels, firm_keys
383+
)
384+
436385
return new_inventory_levels
437386
```
438387

@@ -442,20 +391,20 @@ num_firms = 50_000
442391
443392
sample_dates = 0, 10, 50, 250, 500, 750
444393
445-
first_diffs = np.diff(sample_dates)
394+
first_diffs = jnp.diff(jnp.array(sample_dates))
446395
447396
fig, ax = plt.subplots()
448397
449-
X = np.full(num_firms, x_init)
398+
X = jnp.full(num_firms, x_init)
450399
451400
current_date = 0
452401
for d in first_diffs:
453-
X = shift_firms_forward(X, d)
402+
X = shift_firms_forward(firm, X, d, random.PRNGKey(d))
454403
current_date += d
455-
plot_kde(X, ax, label=f't = {current_date}')
404+
plot_kde(X, ax, label=f"t = {current_date}")
456405
457-
ax.set_xlabel('inventory')
458-
ax.set_ylabel('probability')
406+
ax.set_xlabel("inventory")
407+
ax.set_ylabel("probability")
459408
ax.legend()
460409
plt.show()
461410
```
@@ -490,51 +439,94 @@ You will need a large sample size to get an accurate reading.
490439

491440
Here is one solution.
492441

493-
Again, the computations are relatively intensive so we have written a a
494-
specialized function rather than using the class above.
442+
Again, the computations are relatively intensive so we have written a
443+
specialized JAX-jitted function and using `jax.vmap` to use parallelization across firms.
495444

496-
We will also use parallelization across firms.
445+
Note the time the routine takes to run, as well as the output.
497446

498447
```{code-cell} ipython3
499-
# TODO: Update this to JAX
500-
@jit(parallel=True)
501-
def compute_freq(sim_length=50, x_init=70, num_firms=1_000_000):
502-
503-
firm_counter = 0 # Records number of firms that restock 2x or more
504-
for m in prange(num_firms):
505-
x = x_init
506-
restock_counter = 0 # Will record number of restocks for firm m
507-
508-
for t in range(sim_length):
509-
Z = np.random.randn()
510-
D = np.exp(mu + sigma * Z)
511-
if x <= s:
512-
x = max(S - D, 0)
513-
restock_counter += 1
514-
else:
515-
x = max(x - D, 0)
516-
517-
if restock_counter > 1:
518-
firm_counter += 1
519-
520-
return firm_counter / num_firms
448+
@jax.jit
449+
def simulate_single_firm(period_keys):
450+
"""
451+
Simulate a single firm and count restocks.
452+
453+
Args:
454+
period_keys: Random key for all the periods
455+
456+
Returns:
457+
1 if firm restocks > 1 times, 0 otherwise
458+
"""
459+
460+
def update_step(carry, period_key):
461+
x, restock_count = carry
462+
Z = random.normal(period_key)
463+
D = jnp.exp(firm.μ + firm.σ * Z)
464+
465+
# Check if we need to restock and update accordingly
466+
def restock_branch():
467+
new_x = jnp.maximum(firm.S - D, 0.0)
468+
new_restock_count = restock_count + 1
469+
return (new_x, new_restock_count)
470+
471+
def no_restock_branch():
472+
new_x = jnp.maximum(x - D, 0.0)
473+
return (new_x, restock_count)
474+
475+
new_carry = jax.lax.cond(
476+
x <= firm.s, restock_branch, no_restock_branch
477+
)
478+
479+
return new_carry, None
480+
481+
# Initial state: (inventory_level, restock_count)
482+
initial_carry = (x_init, 0)
483+
484+
# Simulate through all periods
485+
(final_x, total_restocks), _ = jax.lax.scan(
486+
update_step, initial_carry, period_keys
487+
)
488+
489+
# Return 1 if restocked more than once, 0 otherwise
490+
return jnp.where(total_restocks > 1, 1, 0)
491+
492+
493+
# Vectorize the simulation across all firms
494+
vectorized_simulate = jax.vmap(simulate_single_firm, in_axes=(0,))
521495
```
522496

523-
Note the time the routine takes to run, as well as the output.
497+
```{code-cell} ipython3
498+
def compute_freq(
499+
firm, sim_length=50, x_init=70, num_firms=1_000_000, key=random.PRNGKey(2)
500+
):
501+
"""
502+
Compute the frequency of firms that restock 2 or more times using JAX.
503+
504+
Args:
505+
firm: Firm dataclass
506+
sim_length: Length of simulation for each firm
507+
x_init: Initial inventory level for all firms
508+
num_firms: Number of firms to simulate
509+
key: JAX random key
510+
511+
Returns:
512+
Fraction of firms that restock 2 or more times
513+
"""
514+
# Generate independent random keys for each firm
515+
firm_keys = random.split(key, (num_firms, sim_length))
516+
# Run simulation for all firms
517+
restock_indicators = vectorized_simulate(firm_keys)
518+
# Compute frequency (fraction of firms that restocked > 1 times)
519+
frequency = jnp.mean(restock_indicators)
520+
return frequency
521+
```
524522

525523
```{code-cell} ipython3
526524
%%time
527525
528-
freq = compute_freq()
526+
freq = compute_freq(firm)
529527
print(f"Frequency of at least two stock outs = {freq}")
530528
```
531529

532-
Try switching the `parallel` flag to `False` in the jitted function
533-
above.
534-
535-
Depending on your system, the difference can be substantial.
536-
537-
(On our desktop machine, the speed up is by a factor of 5.)
538530

539531
```{solution-end}
540532
```

0 commit comments

Comments
 (0)