@@ -3,8 +3,10 @@ jupytext:
33 text_representation :
44 extension : .md
55 format_name : myst
6+ format_version : 0.13
7+ jupytext_version : 1.16.7
68kernelspec :
7- display_name : Python 3
9+ display_name : Python 3 (ipykernel)
810 language : python
911 name : python3
1012---
@@ -46,11 +48,21 @@ While our Markov environment and many of the concepts we consider are related to
4648
4749Let's start with some imports
4850
51+ ``` {code-cell} ipython3
52+ import jupyter_black
53+ jupyter_black.load(line_length=79)
54+ ```
55+
4956``` {code-cell} ipython3
5057import matplotlib.pyplot as plt
5158import numpy as np
52- from numba import jit, float64, prange
53- from numba.experimental import jitclass
59+ from typing import NamedTuple
60+ import jax
61+ import jax.numpy as jnp
62+ from jax import random
63+
64+ # from numba import jit, float64, prange
65+ # from numba.experimental import jitclass
5466```
5567
5668## Sample Paths
@@ -83,64 +95,91 @@ and standard normal.
8395
Here's a class that stores the firm's parameters, together with a function that generates time paths for inventory.
8597
86- ``` {code-cell} python3
87- firm_data = [
88- ('s', float64), # restock trigger level
89- ('S', float64), # capacity
90- ('mu', float64), # shock location parameter
91- ('sigma', float64) # shock scale parameter
92- ]
93-
94-
95- @jitclass(firm_data)
96- class Firm:
97-
98- def __init__(self, s=10, S=100, mu=1.0, sigma=0.5):
99-
100- self.s, self.S, self.mu, self.sigma = s, S, mu, sigma
101-
102- def update(self, x):
103- "Update the state from t to t+1 given current state x."
104-
105- Z = np.random.randn()
106- D = np.exp(self.mu + self.sigma * Z)
107- if x <= self.s:
108- return max(self.S - D, 0)
109- else:
110- return max(x - D, 0)
98+ ``` {code-cell} ipython3
class Firm(NamedTuple):
    """Parameters of a firm that follows an (s, S) inventory policy.

    Every field has a default, so ``Firm()`` builds the baseline firm and
    individual parameters can be overridden by keyword.
    """

    s: int = 10  # restock trigger level
    S: int = 100  # capacity
    μ: float = 1.0  # shock location parameter
    σ: float = 0.5  # shock scale parameter
104+ ```
111105
112- def sim_inventory_path(self, x_init, sim_length):
106+ ``` {code-cell} ipython3
@jax.jit
def sim_inventory_path(firm, x_init, random_keys):
    """
    Simulate one inventory path under the (s, S) policy.

    Args:
        firm: Object exposing the fields s, S, μ and σ
        x_init: Initial inventory level (scalar)
        random_keys: Array of JAX random keys, one per transition; its
            leading dimension sets the number of simulated steps

    Returns:
        Array of inventory levels holding x_init followed by one entry
        per key, so its length is len(random_keys) + 1
    """
    # Cast to the default float dtype so that the scan carry has the same
    # dtype on input and output even when x_init is an integer array.
    x0 = jnp.asarray(x_init, dtype=float)

    def update_step(x, key_t):
        """
        Advance inventory by one period.

        Args:
            x: Current inventory level (the scan carry)
            key_t: Random key for this time step

        Returns:
            (new_x, new_x): the updated level, once as the next carry and
            once as the value recorded in the path
        """
        # Draw lognormal demand
        Z = random.normal(key_t)
        D = jnp.exp(firm.μ + firm.σ * Z)

        # At or below the trigger s the firm restocks to S before demand
        # is served; inventory can never fall below zero
        stock = jnp.where(x <= firm.s, firm.S, x)
        new_x = jnp.maximum(stock - D, 0.0)

        return new_x, new_x

    # Iterate over the time steps; the final carry is not needed
    _, X_path = jax.lax.scan(update_step, x0, random_keys)

    # Prepend the initial value
    return jnp.concatenate([x0[None], X_path])
157+ ```
113158
114- X = np.empty(sim_length)
115- X[0] = x_init
159+ ``` {code-cell} ipython3
160+ firm = Firm(s=10, S=100, μ=1.0, σ=0.5)
161+ ```
116162
117- for t in range(sim_length-1):
118- X[t+1] = self.update(X[t])
119- return X
163+ ``` {code-cell} ipython3
164+ sim_length = 100
165+ x_init = 50
166+ keys = random.split(random.PRNGKey(21), sim_length - 1)
167+ X = sim_inventory_path(firm, x_init, keys)
120168```
121169
122170Let's run a first simulation, of a single path:
123171
124172``` {code-cell} ipython3
125- firm = Firm()
126-
127173s, S = firm.s, firm.S
128- sim_length = 100
129- x_init = 50
130-
131- X = firm.sim_inventory_path(x_init, sim_length)
132174
133175fig, ax = plt.subplots()
134- bbox = (0., 1.02, 1., .102)
135- legend_args = {'ncol': 3,
136- 'bbox_to_anchor': bbox,
137- 'loc': 3,
138- 'mode': 'expand'}
176+ bbox = (0.0, 1.02, 1.0, 0.102)
177+ legend_args = {"ncol": 3, "bbox_to_anchor": bbox, "loc": 3, "mode": "expand"}
139178
140179ax.plot(X, label="inventory")
141- ax.plot(np .full(sim_length, s), ' k--' , label="$s$")
142- ax.plot(np .full(sim_length, S), 'k-' , label="$S$")
143- ax.set_ylim(0, S+ 10)
180+ ax.plot(jnp .full(sim_length, s), " k--" , label="$s$")
181+ ax.plot(jnp .full(sim_length, S), "k-" , label="$S$")
182+ ax.set_ylim(0, S + 10)
144183ax.set_xlabel("time")
145184ax.legend(**legend_args)
146185
@@ -151,17 +190,18 @@ Now let's simulate multiple paths in order to build a more complete picture of
151190the probabilities of different outcomes:
152191
153192``` {code-cell} ipython3
154- sim_length= 200
193+ sim_length = 200
155194fig, ax = plt.subplots()
156195
157- ax.plot(np .full(sim_length, s), ' k--' , label="$s$")
158- ax.plot(np .full(sim_length, S), 'k-' , label="$S$")
159- ax.set_ylim(0, S+ 10)
196+ ax.plot(jnp .full(sim_length, s), " k--" , label="$s$")
197+ ax.plot(jnp .full(sim_length, S), "k-" , label="$S$")
198+ ax.set_ylim(0, S + 10)
160199ax.legend(**legend_args)
161200
162201for i in range(400):
163- X = firm.sim_inventory_path(x_init, sim_length)
164- ax.plot(X, 'b', alpha=0.2, lw=0.5)
202+ keys = random.split(random.PRNGKey(i), sim_length - 1)
203+ X = sim_inventory_path(firm, x_init, keys)
204+ ax.plot(X, "b", alpha=0.2, lw=0.5)
165205
166206plt.show()
167207```
@@ -192,27 +232,30 @@ for ax in axes:
192232ax = axes[0]
193233
194234ax.set_ylim(ymin, ymax)
195- ax.set_ylabel(' $X_t$' , fontsize=16)
235+ ax.set_ylabel(" $X_t$" , fontsize=16)
196236ax.vlines((T,), -1.5, 1.5)
197237
198238ax.set_xticks((T,))
199- ax.set_xticklabels((r' $T$' ,))
239+ ax.set_xticklabels((r" $T$" ,))
200240
201- sample = np.empty(M)
241+ sample = []
202242for m in range(M):
203- X = firm.sim_inventory_path(x_init, 2 * T)
204- ax.plot(X, 'b-', lw=1, alpha=0.5)
205- ax.plot((T,), (X[T+1],), 'ko', alpha=0.5)
206- sample[m] = X[T+1]
243+ keys = random.split(random.PRNGKey(m), sim_length - 1)
244+ X = sim_inventory_path(firm, x_init, keys)
245+ ax.plot(X, "b-", lw=1, alpha=0.5)
246+ ax.plot((T,), (X[T + 1],), "ko", alpha=0.5)
247+ sample.append(X[T + 1])
207248
208249axes[1].set_ylim(ymin, ymax)
209250
210- axes[1].hist(sample,
211- bins=16,
212- density=True,
213- orientation='horizontal',
214- histtype='bar',
215- alpha=0.5)
251+ axes[1].hist(
252+ sample,
253+ bins=16,
254+ density=True,
255+ orientation="horizontal",
256+ histtype="bar",
257+ alpha=0.5,
258+ )
216259
217260plt.show()
218261```
@@ -225,16 +268,13 @@ M = 50_000
225268
226269fig, ax = plt.subplots()
227270
228- sample = np.empty(M)
271+ sample = []
229272for m in range(M):
230- X = firm.sim_inventory_path(x_init, T+1)
231- sample[m] = X[T]
273+ keys = random.split(random.PRNGKey(m), T)
274+ X = sim_inventory_path(firm, x_init, keys)
275+ sample.append(X[T])
232276
233- ax.hist(sample,
234- bins=36,
235- density=True,
236- histtype='bar',
237- alpha=0.75)
277+ ax.hist(sample, bins=36, density=True, histtype="bar", alpha=0.75)
238278
239279plt.show()
240280```
@@ -255,14 +295,15 @@ We will use a kernel density estimator from [scikit-learn](https://scikit-learn.
255295``` {code-cell} ipython3
256296from sklearn.neighbors import KernelDensity
257297
258- def plot_kde(sample, ax, label=''):
259298
def plot_kde(sample, ax, label=""):
    """
    Plot a Gaussian kernel density estimate of sample on ax.

    Args:
        sample: One-dimensional collection of observations (list, NumPy
            array or JAX array)
        ax: Matplotlib axes to draw on
        label: Legend label for the plotted density
    """
    # Convert once to a host NumPy array: scikit-learn and matplotlib work
    # on NumPy data, and the array's own min/max avoid iterating over a
    # JAX array one element at a time with the Python builtins.
    sample = np.asarray(sample, dtype=float)
    xmin, xmax = 0.9 * sample.min(), 1.1 * sample.max()
    xgrid = np.linspace(xmin, xmax, 200)

    # KernelDensity expects 2-D input of shape (n_samples, n_features)
    kde = KernelDensity(kernel="gaussian").fit(sample[:, None])
    log_dens = kde.score_samples(xgrid[:, None])

    ax.plot(xgrid, np.exp(log_dens), label=label)
266307```
267308
268309``` {code-cell} ipython3
@@ -313,25 +354,85 @@ code efficiently.
313354This meant writing a specialized function rather than using the class above.
314355
315356``` {code-cell} ipython3
316- s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
317-
318- @jit(parallel=True)
319- def shift_firms_forward(current_inventory_levels, num_periods):
357+ # s, S, mu, sigma = firm.s, firm.S, firm.mu, firm.sigma
358+
359+ # @jit(parallel=True)
360+ # def shift_firms_forward(current_inventory_levels, num_periods):
361+
362+ # num_firms = len(current_inventory_levels)
363+ # new_inventory_levels = np.empty(num_firms)
364+
365+ # for f in prange(num_firms):
366+ # x = current_inventory_levels[f]
367+ # for t in range(num_periods):
368+ # Z = np.random.randn()
369+ # D = np.exp(mu + sigma * Z)
370+ # if x <= s:
371+ # x = max(S - D, 0)
372+ # else:
373+ # x = max(x - D, 0)
374+ # new_inventory_levels[f] = x
375+
376+ # return new_inventory_levels
377+ ```
320378
379+ ``` {code-cell} ipython3
def shift_firms_forward(
    firm: "Firm",
    current_inventory_levels,
    num_periods: int,
    key: jax.Array,
) -> jax.Array:
    """
    Shift a cross-section of firms forward by num_periods.

    Each firm follows the same (s, S) policy and receives its own
    independent stream of demand shocks.

    Args:
        firm: Firm instance with parameters s, S, μ and σ
        current_inventory_levels: One-dimensional array with the current
            inventory level of each firm
        num_periods: Number of periods to simulate forward
        key: JAX random key from which all per-firm keys are derived

    Returns:
        Array of inventory levels after num_periods, one entry per firm
    """
    # Cast to the default float dtype so that the scan carry has the same
    # dtype on input and output even when integer levels are passed in.
    x0 = jnp.asarray(current_inventory_levels, dtype=float)

    def update_step(x, period_key):
        """Advance one firm's inventory by a single period."""
        Z = random.normal(period_key)
        D = jnp.exp(firm.μ + firm.σ * Z)

        # At or below the trigger s the firm restocks to S before demand
        # is served; inventory can never fall below zero
        stock = jnp.where(x <= firm.s, firm.S, x)

        # Nothing is recorded per period, only the carry matters
        return jnp.maximum(stock - D, 0.0), None

    def simulate_single_firm(x_init, firm_key):
        """Return one firm's inventory level after num_periods."""
        period_keys = random.split(firm_key, num_periods)
        final_x, _ = jax.lax.scan(update_step, x_init, period_keys)
        return final_x

    # One independent key per firm, then vectorize across firms
    firm_keys = random.split(key, x0.shape[0])
    return jax.vmap(simulate_single_firm)(x0, firm_keys)
336437```
337438
@@ -395,6 +496,7 @@ specialized function rather than using the class above.
395496We will also use parallelization across firms.
396497
397498``` {code-cell} ipython3
499+ # TODO: Update this to JAX
398500@jit(parallel=True)
399501def compute_freq(sim_length=50, x_init=70, num_firms=1_000_000):
400502
0 commit comments