@@ -36,33 +36,24 @@ follow so-called s-S inventory dynamics.
3636Such firms
3737
38381 . wait until inventory falls below some level $s$ and then
39- 1 . order sufficient quantities to bring their inventory back up to capacity $S$.
39+ 2 . order sufficient quantities to bring their inventory back up to capacity $S$.
4040
4141These kinds of policies are common in practice and also optimal in certain circumstances.
4242
4343A review of early literature and some macroeconomic implications can be found in {cite}` caplin1985variability ` .
4444
4545Here our main aim is to learn more about simulation, time series and Markov dynamics.
4646
47- While our Markov environment and many of the concepts we consider are related to those found in our {doc}` lecture on finite Markov chains <finite_markov>` , the state space is a continuum in the current application.
47+ While our Markov environment and many of the concepts we consider are related to those found in our lecture {doc}` <finite_markov> ` , the state space is a continuum in the current application.
4848
4949Let's start with some imports
5050
51- ``` {code-cell} ipython3
52- import jupyter_black
53- jupyter_black.load(line_length=79)
54- ```
55-
5651``` {code-cell} ipython3
5752import matplotlib.pyplot as plt
58- import numpy as np
5953from typing import NamedTuple
6054import jax
6155import jax.numpy as jnp
6256from jax import random
63-
64- # from numba import jit, float64, prange
65- # from numba.experimental import jitclass
6657```
6758
6859## Sample Paths
@@ -97,8 +88,8 @@ Here's a class that stores parameters and generates time paths for inventory.
9788
9889``` {code-cell} ipython3
9990class Firm(NamedTuple):
100- s: int # restock trigger level
101- S: int # capacity
91+ s: int # restock trigger level
92+ S: int # capacity
10293 μ: float # shock location parameter
10394 σ: float # shock scale parameter
10495```
@@ -112,8 +103,7 @@ def sim_inventory_path(firm, x_init, random_keys):
112103 Args:
113104 firm: Firm object
114105 x_init: Initial inventory level
115- sim_length: Length of simulation
116- key: JAX random key
106+ random_keys: Array of JAX random keys
117107
118108 Returns:
119109 Array of inventory levels over time
@@ -122,21 +112,11 @@ def sim_inventory_path(firm, x_init, random_keys):
122112 def update_step(carry, key_t):
123113 """
124114 Single update step
125-
126- Args:
127- carry: Current inventory level (x)
128- key_t: Random key for this time step
129-
130- Returns:
131- (new_x, new_x): Updated inventory level (returned twice for scan)
132115 """
133116 x = carry
134-
135- # Generate random demand
136117 Z = random.normal(key_t)
137118 D = jnp.exp(firm.μ + firm.σ * Z)
138119
139- # Update inventory based on (s, S) policy
140120 new_x = jax.lax.cond(
141121 x <= firm.s,
142122 lambda: jnp.maximum(
@@ -243,8 +223,8 @@ for m in range(M):
243223 keys = random.split(random.PRNGKey(m), sim_length - 1)
244224 X = sim_inventory_path(firm, x_init, keys)
245225 ax.plot(X, "b-", lw=1, alpha=0.5)
246- ax.plot((T,), (X[T + 1],), "ko", alpha=0.5)
247- sample.append(X[T + 1])
226+ ax.plot((T,), (X[T+ 1],), "ko", alpha=0.5)
227+ sample.append(X[T+ 1])
248228
249229axes[1].set_ylim(ymin, ymax)
250230
@@ -349,90 +329,59 @@ Try different initial conditions to verify that, in the long run, the distributi
349329Below is one possible solution:
350330
351331The computations involve a lot of CPU cycles so we have tried to write the
352- code efficiently.
353-
354- This meant writing a specialized function rather than using the class above.
332+ code efficiently using ` jax.jit ` and ` jax.vmap ` to run on CPU/GPU.
355333
356334``` {code-cell} ipython3
357- # s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
358-
359- # @jit(parallel=True)
360- # def shift_firms_forward(current_inventory_levels, num_periods):
361-
362- # num_firms = len(current_inventory_levels)
363- # new_inventory_levels = np.empty(num_firms)
364-
365- # for f in prange(num_firms):
366- # x = current_inventory_levels[f]
367- # for t in range(num_periods):
368- # Z = np.random.randn()
369- # D = np.exp(mu + sigma * Z)
370- # if x <= s:
371- # x = max(S - D, 0)
372- # else:
373- # x = max(x - D, 0)
374- # new_inventory_levels[f] = x
375-
376- # return new_inventory_levels
335+ @jax.jit
336+ def simulate_single_firm(x_init, period_keys):
337+ """
338+ Simulate a single firm forward by num_periods.
339+
340+ Args:
341+ x_init: Initial inventory level for this firm
342+ period_keys: Random key for this firm for each period
343+ """
344+
345+ def update_step(x, period_key):
346+ Z = random.normal(period_key)
347+ D = jnp.exp(firm.μ + firm.σ * Z)
348+
349+ new_x = jax.lax.cond(
350+ x <= firm.s,
351+ lambda: jnp.maximum(firm.S - D, 0.0),
352+ lambda: jnp.maximum(x - D, 0.0),
353+ )
354+ return (
355+ new_x,
356+ None,
357+ ) # Return None for scan accumulator (we don't need it)
358+
359+ # Simulate forward num_periods
360+ final_x, _ = jax.lax.scan(update_step, x_init, period_keys)
361+
362+ return final_x
363+
364+
365+ # Vectorize over all firms using vmap
366+ vectorized_simulate = jax.vmap(simulate_single_firm, in_axes=(0, 0))
377367```
378368
379369``` {code-cell} ipython3
380- def shift_firms_forward(firm: Firm , current_inventory_levels, num_periods: int , key: random.PRNGKey ):
370+ def shift_firms_forward(firm, current_inventory_levels, num_periods, key):
381371 """
382372 Shift multiple firms forward by num_periods using JAX vectorization.
383-
384- Args:
385- firm: Firm dataclass with parameters s, S, mu, sigma
386- current_inventory_levels: Array of current inventory levels for each firm
387- num_periods: Number of periods to simulate forward
388- key: JAX random key
389-
390373 Returns:
391374 Array of new inventory levels after num_periods
392375 """
393-
394- def simulate_single_firm(x_init, firm_key):
395- """
396- Simulate a single firm forward by num_periods.
397-
398- Args:
399- x_init: Initial inventory level for this firm
400- firm_key: Random key for this firm
401-
402- Returns:
403- Final inventory level after num_periods
404- """
405-
406- def update_step(x, period_key):
407- """Single period update step."""
408- Z = random.normal(period_key)
409- D = jnp.exp(firm.mu + firm.sigma * Z)
410-
411- new_x = jax.lax.cond(
412- x <= firm.s,
413- lambda: jnp.maximum(firm.S - D, 0.0),
414- lambda: jnp.maximum(x - D, 0.0)
415- )
416- return new_x, None # Return None for scan accumulator (we don't need it)
417-
418- # Generate keys for each period
419- period_keys = random.split(firm_key, num_periods)
420-
421- # Simulate forward num_periods
422- final_x, _ = jax.lax.scan(update_step, x_init, period_keys)
423-
424- return final_x
425-
376+
426377 # Generate independent random keys for each firm
427378 num_firms = len(current_inventory_levels)
428- firm_keys = random.split(key, num_firms)
429-
430- # Vectorize over all firms using vmap
431- vectorized_simulate = jax.vmap(simulate_single_firm, in_axes=(0, 0))
432-
379+ firm_keys = random.split(key, (num_firms, num_periods))
433380 # Run simulation for all firms in parallel
434- new_inventory_levels = vectorized_simulate(current_inventory_levels, firm_keys)
435-
381+ new_inventory_levels = vectorized_simulate(
382+ current_inventory_levels, firm_keys
383+ )
384+
436385 return new_inventory_levels
437386```
438387
@@ -442,20 +391,20 @@ num_firms = 50_000
442391
443392sample_dates = 0, 10, 50, 250, 500, 750
444393
445- first_diffs = np .diff(sample_dates)
394+ first_diffs = jnp .diff(jnp.array( sample_dates) )
446395
447396fig, ax = plt.subplots()
448397
449- X = np .full(num_firms, x_init)
398+ X = jnp .full(num_firms, x_init)
450399
451400current_date = 0
452401for d in first_diffs:
453- X = shift_firms_forward(X, d)
402+ X = shift_firms_forward(firm, X, d, random.PRNGKey(d) )
454403 current_date += d
455- plot_kde(X, ax, label=f' t = {current_date}' )
404+ plot_kde(X, ax, label=f" t = {current_date}" )
456405
457- ax.set_xlabel(' inventory' )
458- ax.set_ylabel(' probability' )
406+ ax.set_xlabel(" inventory" )
407+ ax.set_ylabel(" probability" )
459408ax.legend()
460409plt.show()
461410```
@@ -490,51 +439,94 @@ You will need a large sample size to get an accurate reading.
490439
491440Here is one solution.
492441
493- Again, the computations are relatively intensive so we have written a a
494- specialized function rather than using the class above .
442+ Again, the computations are relatively intensive so we have written a
443+ specialized JAX-jitted function and using ` jax.vmap ` to use parallelization across firms .
495444
496- We will also use parallelization across firms .
445+ Note the time the routine takes to run, as well as the output .
497446
498447``` {code-cell} ipython3
499- # TODO: Update this to JAX
500- @jit(parallel=True)
501- def compute_freq(sim_length=50, x_init=70, num_firms=1_000_000):
502-
503- firm_counter = 0 # Records number of firms that restock 2x or more
504- for m in prange(num_firms):
505- x = x_init
506- restock_counter = 0 # Will record number of restocks for firm m
507-
508- for t in range(sim_length):
509- Z = np.random.randn()
510- D = np.exp(mu + sigma * Z)
511- if x <= s:
512- x = max(S - D, 0)
513- restock_counter += 1
514- else:
515- x = max(x - D, 0)
516-
517- if restock_counter > 1:
518- firm_counter += 1
519-
520- return firm_counter / num_firms
448+ @jax.jit
449+ def simulate_single_firm(period_keys):
450+ """
451+ Simulate a single firm and count restocks.
452+
453+ Args:
454+ period_keys: Random key for all the periods
455+
456+ Returns:
457+ 1 if firm restocks > 1 times, 0 otherwise
458+ """
459+
460+ def update_step(carry, period_key):
461+ x, restock_count = carry
462+ Z = random.normal(period_key)
463+ D = jnp.exp(firm.μ + firm.σ * Z)
464+
465+ # Check if we need to restock and update accordingly
466+ def restock_branch():
467+ new_x = jnp.maximum(firm.S - D, 0.0)
468+ new_restock_count = restock_count + 1
469+ return (new_x, new_restock_count)
470+
471+ def no_restock_branch():
472+ new_x = jnp.maximum(x - D, 0.0)
473+ return (new_x, restock_count)
474+
475+ new_carry = jax.lax.cond(
476+ x <= firm.s, restock_branch, no_restock_branch
477+ )
478+
479+ return new_carry, None
480+
481+ # Initial state: (inventory_level, restock_count)
482+ initial_carry = (x_init, 0)
483+
484+ # Simulate through all periods
485+ (final_x, total_restocks), _ = jax.lax.scan(
486+ update_step, initial_carry, period_keys
487+ )
488+
489+ # Return 1 if restocked more than once, 0 otherwise
490+ return jnp.where(total_restocks > 1, 1, 0)
491+
492+
493+ # Vectorize the simulation across all firms
494+ vectorized_simulate = jax.vmap(simulate_single_firm, in_axes=(0,))
521495```
522496
523- Note the time the routine takes to run, as well as the output.
497+ ``` {code-cell} ipython3
498+ def compute_freq(
499+ firm, sim_length=50, x_init=70, num_firms=1_000_000, key=random.PRNGKey(2)
500+ ):
501+ """
502+ Compute the frequency of firms that restock 2 or more times using JAX.
503+
504+ Args:
505+ firm: Firm dataclass
506+ sim_length: Length of simulation for each firm
507+ x_init: Initial inventory level for all firms
508+ num_firms: Number of firms to simulate
509+ key: JAX random key
510+
511+ Returns:
512+ Fraction of firms that restock 2 or more times
513+ """
514+ # Generate independent random keys for each firm
515+ firm_keys = random.split(key, (num_firms, sim_length))
516+ # Run simulation for all firms
517+ restock_indicators = vectorized_simulate(firm_keys)
518+ # Compute frequency (fraction of firms that restocked > 1 times)
519+ frequency = jnp.mean(restock_indicators)
520+ return frequency
521+ ```
524522
525523``` {code-cell} ipython3
526524%%time
527525
528- freq = compute_freq()
526+ freq = compute_freq(firm )
529527print(f"Frequency of at least two stock outs = {freq}")
530528```
531529
532- Try switching the ` parallel ` flag to ` False ` in the jitted function
533- above.
534-
535- Depending on your system, the difference can be substantial.
536-
537- (On our desktop machine, the speed up is by a factor of 5.)
538530
539531``` {solution-end}
540532```
0 commit comments