Skip to content

Commit cf28a4c

Browse files
CopilotmmckyHumphreyYang
authored
Replace np.sum(a * b) with a @ b for better performance and accuracy (#542)
* Initial plan * Replace np.sum(a * b) with a @ b in linear_algebra, career, lake_model, and mccall_model_with_separation Co-authored-by: mmcky <[email protected]> * Additional replacements: mix_model, kalman, and rand_resp files Co-authored-by: mmcky <[email protected]> * Update lectures/lake_model.md * Update lectures/linear_algebra.md * revert change in rand_resp --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: mmcky <[email protected]> Co-authored-by: Humphrey Yang <[email protected]> Co-authored-by: Humphrey Yang <[email protected]>
1 parent a8328e1 commit cf28a4c

File tree

6 files changed

+23
-11
lines changed

6 files changed

+23
-11
lines changed

lectures/career.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ class CareerWorkerProblem:
206206
207207
self.F_probs = BetaBinomial(grid_size - 1, F_a, F_b).pdf()
208208
self.G_probs = BetaBinomial(grid_size - 1, G_a, G_b).pdf()
209-
self.F_mean = np.sum(self.θ * self.F_probs)
210-
self.G_mean = np.sum(self.ϵ * self.G_probs)
209+
self.F_mean = self.θ @ self.F_probs
210+
self.G_mean = self.ϵ @ self.G_probs
211211
212212
# Store these parameters for str and repr methods
213213
self._F_a, self._F_b = F_a, F_b

lectures/kalman.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -783,8 +783,10 @@ e2 = np.empty(T-1)
783783
784784
for t in range(1, T):
785785
kn.update(y[:,t])
786-
e1[t-1] = np.sum((x[:, t] - kn.x_hat.flatten())**2)
787-
e2[t-1] = np.sum((x[:, t] - A @ x[:, t-1])**2)
786+
diff1 = x[:, t] - kn.x_hat.flatten()
787+
diff2 = x[:, t] - A @ x[:, t-1]
788+
e1[t-1] = diff1 @ diff1
789+
e2[t-1] = diff2 @ diff2
788790
789791
fig, ax = plt.subplots(figsize=(9,6))
790792
ax.plot(range(1, T), e1, 'k-', lw=2, alpha=0.6,

lectures/lake_model.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ def _update_bellman(α, β, γ, c, σ, w_vec, p_vec, V, V_new, U):
699699
V_new[w_idx] = u(w, σ) + β * ((1 - α) * V[w_idx] + α * U)
700700
701701
U_new = u(c, σ) + β * (1 - γ) * U + \
702-
β * γ * np.sum(np.maximum(U, V) * p_vec)
702+
β * γ * (np.maximum(U, V) @ p_vec)
703703
704704
return U_new
705705
@@ -836,8 +836,8 @@ def compute_steady_state_quantities(c, τ):
836836
u, e = x
837837
838838
# Compute steady state welfare
839-
w = np.sum(V * p_vec * (w_vec - τ > w_bar)) / np.sum(p_vec * (w_vec -
840-
τ > w_bar))
839+
mask = (w_vec - τ > w_bar)
840+
w = ((V * p_vec * mask) @ np.ones_like(p_vec)) / ((p_vec * mask) @ np.ones_like(p_vec))
841841
welfare = e * w + u * U
842842
843843
return e, u, welfare

lectures/linear_algebra.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,25 @@ Continuing on from the previous example, the inner product and norm can be compu
246246
follows
247247

248248
```{code-cell} python3
249-
np.sum(x * y) # Inner product of x and y
249+
np.sum(x * y) # Inner product of x and y, method 1
250250
```
251251

252+
```{code-cell} python3
253+
x @ y # Inner product of x and y, method 2 (preferred)
254+
```
255+
256+
The `@` operator is preferred because it uses optimized BLAS libraries that implement fused multiply-add operations, providing better performance and numerical accuracy compared to the separate multiply and sum operations.
257+
252258
```{code-cell} python3
253259
np.sqrt(np.sum(x**2)) # Norm of x, take one
254260
```
255261

256262
```{code-cell} python3
257-
np.linalg.norm(x) # Norm of x, take two
263+
np.sqrt(x @ x) # Norm of x, take two (preferred)
264+
```
265+
266+
```{code-cell} python3
267+
np.linalg.norm(x) # Norm of x, take three
258268
```
259269

260270
### Span

lectures/mccall_model_with_separation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def update(model, v, d):
345345
" One update on the Bellman equations. "
346346
α, β, c, w, q = model.α, model.β, model.c, model.w, model.q
347347
v_new = u(w) + β * ((1 - α) * v + α * d)
348-
d_new = jnp.sum(jnp.maximum(v, u(c) + β * d) * q)
348+
d_new = jnp.maximum(v, u(c) + β * d) @ q
349349
return v_new, d_new
350350
351351
@jax.jit

lectures/mix_model.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def learn_x_bayesian(observations, α0, β0, grid_size=2000):
820820
post = np.exp(log_post)
821821
post /= post.sum()
822822
823-
μ_path[t + 1] = np.sum(x_grid * post)
823+
μ_path[t + 1] = x_grid @ post
824824
825825
return μ_path
826826

0 commit comments

Comments
 (0)