Skip to content

Commit 8d60416

Browse files
committed
Add GPU selection flag and document benchmarks
1 parent 6c65c11 commit 8d60416

File tree

2 files changed

+110
-12
lines changed

2 files changed

+110
-12
lines changed

modal_app/fit_weights.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,14 @@
1515
REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git"
1616

1717

18-
@app.function(
19-
image=image,
20-
secrets=[hf_secret],
21-
memory=32768,
22-
cpu=4.0,
23-
gpu="A100-80GB",
24-
timeout=14400,
25-
)
26-
def fit_weights(branch: str = "main", epochs: int = 200) -> bytes:
18+
def _fit_weights_impl(branch: str, epochs: int) -> bytes:
19+
"""Shared implementation for weight fitting."""
2720
os.chdir("/root")
2821
subprocess.run(["git", "clone", "-b", branch, REPO_URL], check=True)
2922
os.chdir("policyengine-us-data")
3023

3124
subprocess.run(["uv", "sync", "--extra", "l0"], check=True)
3225

33-
# Download calibration inputs from HuggingFace
3426
print("Downloading calibration inputs from HuggingFace...")
3527
download_result = subprocess.run(
3628
[
@@ -51,7 +43,6 @@ def fit_weights(branch: str = "main", epochs: int = 200) -> bytes:
5143
if download_result.returncode != 0:
5244
raise RuntimeError(f"Download failed: {download_result.returncode}")
5345

54-
# Parse paths from output
5546
db_path = dataset_path = None
5647
for line in download_result.stdout.split('\n'):
5748
if line.startswith('DB:'):
@@ -90,13 +81,71 @@ def fit_weights(branch: str = "main", epochs: int = 200) -> bytes:
9081
return f.read()
9182

9283

84+
@app.function(
85+
image=image, secrets=[hf_secret], memory=32768, cpu=4.0,
86+
gpu="T4", timeout=14400,
87+
)
88+
def fit_weights_t4(branch: str = "main", epochs: int = 200) -> bytes:
89+
return _fit_weights_impl(branch, epochs)
90+
91+
92+
@app.function(
93+
image=image, secrets=[hf_secret], memory=32768, cpu=4.0,
94+
gpu="A10", timeout=14400,
95+
)
96+
def fit_weights_a10(branch: str = "main", epochs: int = 200) -> bytes:
97+
return _fit_weights_impl(branch, epochs)
98+
99+
100+
@app.function(
101+
image=image, secrets=[hf_secret], memory=32768, cpu=4.0,
102+
gpu="A100-40GB", timeout=14400,
103+
)
104+
def fit_weights_a100_40(branch: str = "main", epochs: int = 200) -> bytes:
105+
return _fit_weights_impl(branch, epochs)
106+
107+
108+
@app.function(
109+
image=image, secrets=[hf_secret], memory=32768, cpu=4.0,
110+
gpu="A100-80GB", timeout=14400,
111+
)
112+
def fit_weights_a100_80(branch: str = "main", epochs: int = 200) -> bytes:
113+
return _fit_weights_impl(branch, epochs)
114+
115+
116+
@app.function(
117+
image=image, secrets=[hf_secret], memory=32768, cpu=4.0,
118+
gpu="H100", timeout=14400,
119+
)
120+
def fit_weights_h100(branch: str = "main", epochs: int = 200) -> bytes:
121+
return _fit_weights_impl(branch, epochs)
122+
123+
124+
GPU_FUNCTIONS = {
125+
"T4": fit_weights_t4,
126+
"A10": fit_weights_a10,
127+
"A100-40GB": fit_weights_a100_40,
128+
"A100-80GB": fit_weights_a100_80,
129+
"H100": fit_weights_h100,
130+
}
131+
132+
93133
@app.local_entrypoint()
94134
def main(
95135
branch: str = "main",
96136
epochs: int = 200,
137+
gpu: str = "T4",
97138
output: str = "calibration_weights.npy"
98139
):
99-
weights_bytes = fit_weights.remote(branch=branch, epochs=epochs)
140+
if gpu not in GPU_FUNCTIONS:
141+
raise ValueError(
142+
f"Unknown GPU: {gpu}. Choose from: {list(GPU_FUNCTIONS.keys())}"
143+
)
144+
145+
print(f"Running with GPU: {gpu}, epochs: {epochs}, branch: {branch}")
146+
func = GPU_FUNCTIONS[gpu]
147+
weights_bytes = func.remote(branch=branch, epochs=epochs)
148+
100149
with open(output, 'wb') as f:
101150
f.write(weights_bytes)
102151
print(f"Weights saved to: {output}")

policyengine_us_data/datasets/cps/local_area_calibration/ADDING_CALIBRATION_TARGETS.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,52 @@ For most new targets:
239239
3. Run and verify with `MatrixTracer`
240240

241241
No code changes to `sparse_matrix_builder.py` needed unless you have special aggregation or constraint requirements.
242+
243+
## Running Weight Calibration on Modal (GPU)
244+
245+
The `fit_calibration_weights.py` script can be run on Modal with GPU acceleration using `modal_app/fit_weights.py`.
246+
247+
### Basic Usage
248+
249+
```bash
250+
# Default: T4 GPU, 200 epochs
251+
modal run modal_app/fit_weights.py --branch main --epochs 200
252+
253+
# Specify GPU type (one of: T4, A10, A100-40GB, A100-80GB, H100)
254+
modal run modal_app/fit_weights.py --branch main --epochs 2000 --gpu A100-40GB
255+
```
256+
257+
### GPU Benchmarks (200 epochs, 2 target groups, Jan 2026)
258+
259+
| GPU | Time | Cost | Notes |
260+
|-----|------|------|-------|
261+
| T4 | 16m 4s | $0.16 | Best for small test runs |
262+
| A100-40GB | 9m 5s | $0.32 | ~43% less wall-clock time than T4 (~1.8x speedup) |
263+
| A100-80GB | 10m 28s | $0.44 | Slower than A100-40GB despite higher memory bandwidth (possibly run-to-run variance) |
264+
265+
### Key Findings
266+
267+
1. **Memory bandwidth matters for sparse operations**: The P100 (not available on Modal) outperforms T4 by ~2x on Kaggle due to HBM2 memory (~732 GB/s) vs GDDR6 (~320 GB/s).
268+
269+
2. **Significant overhead at low epochs**: With only 200 epochs, much of the runtime is fixed overhead (~6-8 min in total) that a faster GPU does not reduce:
270+
- Git clone and `uv sync` (~2-3 min)
271+
- HuggingFace data download (~1 min)
272+
- Loading Microsimulation and building sparse matrix (~3-4 min, CPU-bound)
273+
274+
3. **GPU choice depends on epoch count**:
275+
- **< 500 epochs**: Use T4 (cheapest, overhead dominates)
276+
- **500-2000 epochs**: A100-40GB may break even with T4 on cost
277+
- **> 2000 epochs**: A100 likely more cost-effective as training dominates
278+
279+
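4. **Available Modal GPUs** (memory bandwidth and price per second, listed in order of increasing price; the `--gpu` flag currently accepts only T4, A10, A100-40GB, A100-80GB, and H100):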
280+
- T4: 320 GB/s, $0.000164/sec
281+
- L4: 300 GB/s, $0.000222/sec
282+
- A10: 600 GB/s, $0.000306/sec
283+
- L40S: 864 GB/s, $0.000542/sec
284+
- A100-40GB: 1,555 GB/s, $0.000583/sec
285+
- A100-80GB: 2,039 GB/s, $0.000694/sec
286+
- H100: 3,350 GB/s, $0.001097/sec
287+
288+
### Output
289+
290+
By default, weights are saved locally to `calibration_weights.npy` in the current working directory (configurable via the `--output` flag).

0 commit comments

Comments
 (0)