
Commit c677d2c

feat: add llama3.2 support and experiments (#2)
* Improved compilation times
* Added support for model loading, tokenizers, and complex rope types
* Added the llama3.2 implementation in Tempo/JAX/Torch
* Added thunk wrappers (in-place writes and lazy reads), which provide a major speed-up by eliminating redundant copies; a rough sketch follows below
* Added an initial numpy backend
* Added thunk code-generation ability, dropping the need for DL backends to trace thunks themselves
* Misc bugfixes and refactorings
1 parent dc46762 commit c677d2c
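The thunk wrappers are the headline runtime change of this commit. Below is a rough, illustrative sketch of the two ideas, using numpy for concreteness; the names `LazyRead` and `inplace_wrap` are hypothetical, not Tempo's actual API:

```python
import numpy as np


class LazyRead:
    """Defers materializing an input until the thunk actually touches it."""

    def __init__(self, fetch):
        self._fetch = fetch  # zero-argument callable producing an ndarray
        self._value = None

    def get(self) -> np.ndarray:
        if self._value is None:  # fetch at most once, and only on demand
            self._value = self._fetch()
        return self._value


def inplace_wrap(thunk, out: np.ndarray):
    """Wrap a thunk so it writes directly into a preallocated buffer
    instead of returning a fresh array that must then be copied."""

    def wrapped(*inputs):
        thunk(*(i.get() for i in inputs), out=out)  # result lands in `out`
        return out

    return wrapped


# Example: an elementwise add whose result is written straight into `buf`.
buf = np.empty(4)
add = inplace_wrap(np.add, buf)
print(add(LazyRead(lambda: np.ones(4)), LazyRead(lambda: np.full(4, 2.0))))  # [3. 3. 3. 3.]
```

The intent, as described in the commit message, is that results land directly in buffers the runtime already owns, and inputs are only materialized if the thunk actually reads them, which is where the eliminated copies come from.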

239 files changed (+13279 / -6974 lines)


.gitignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -29,6 +29,7 @@ test_run/
 debug_run/
 debug_runs/
 results/
+results_profile/
 examples/experiments/attn_microbench/results/
 data/
 !tempo/api/data/
```

.pre-commit-config.yaml

Lines changed: 12 additions & 8 deletions

```diff
@@ -1,17 +1,23 @@
-default_stages: [ "commit", "commit-msg", "push" ]
+default_stages: ["pre-commit", "commit-msg", "pre-push"]
 default_language_version:
   python: python3
 
 exclude: ^examples/llama/
 
 repos:
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.20.0
+    hooks:
+      - id: pyupgrade
+        args: [--py310-plus]
+
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
     rev: v0.12.1
     hooks:
       # Run the linter.
       - id: ruff
-        args: [ --fix ]
+        args: [--fix]
         files: ^tempo/
       # Run the formatter.
       - id: ruff-format
@@ -29,10 +35,10 @@ repos:
         name: "Mixed line ending fixer"
       - id: check-yaml
         name: "Yaml checker"
-        args: [ '--unsafe' ]
+        args: ["--unsafe"]
       - id: trailing-whitespace
         name: "Trailing whitespace fixer"
-        args: ['--markdown-linebreak-ext=md']
+        args: ["--markdown-linebreak-ext=md"]
 
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.16.1
@@ -41,15 +47,13 @@ repos:
         name: "Static type checker"
         files: tempo/.*\.py$
 
-
   - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
     rev: v9.22.0
     hooks:
       - id: commitlint
         name: "Commit linter"
-        stages: [ commit-msg ]
-        additional_dependencies: [ '@commitlint/config-conventional' ]
-
+        stages: [commit-msg]
+        additional_dependencies: ["@commitlint/config-conventional"]
 
   - repo: https://github.com/kynan/nbstripout
     rev: 0.8.1
```
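The new pyupgrade hook with `--py310-plus` is what drives the typing rewrites visible in the `repro/data_loading.py` diff further down: deprecated `typing` aliases become builtin generics, and `Union[X, None]` becomes `X | None`. An abridged before/after, using signatures adapted from that diff:

```python
import pandas as pd

# Before (pre-3.10 typing aliases):
from typing import Any, Dict, List, Union

def read_csv(path: str) -> Union[pd.DataFrame, None]: ...
def load_sweep_data(sweeps: Dict[str, List[Any]], systems: List[str]) -> Dict[str, Any]: ...

# After `pyupgrade --py310-plus` (builtin generics, PEP 604 unions):
def read_csv(path: str) -> pd.DataFrame | None: ...
def load_sweep_data(sweeps: dict[str, list[Any]], systems: list[str]) -> dict[str, Any]: ...
```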

repro/README.md

Lines changed: 46 additions & 4 deletions

````diff
@@ -60,9 +60,9 @@ repro                              # Package containing all reproducib
 │
 ├── expected_results/               # PNG examples of the expected plots and speedup analysis
 │
-├── sec7_2_lm_decode/               # Scripts for running and plotting Section 7.2's experiments
+├── sec7_2_[lm/llama32]_decode/     # Scripts for running and plotting Section 7.2's experiments
 │   │
-│   ├── impls/                      # Implementations of GPT2's architecture in JAX/Torch/Tempo
+│   ├── impls/                      # Implementations of [GPT2/Llamas]'s architecture in JAX/Torch/Tempo
 │   ├── plot/                       # Plotting scripts for Section 7.2
 │   │   ├── plot_gpt2_time_per_token.py  # Script to plot Figure 9 and 10
 │   │   ├── plot_block_size.py      # Script to plot Figure 11
@@ -138,12 +138,12 @@ We have aimed to make this process as simple as possible:
 
 git clone https://github.com/lsds/Tempo/ tempo
 cd tempo
-chmod +x repro/build_run_container.sh
+chmod +x repro/build_run_container.sh [--llama32]
 
 ./repro/build_run_container.sh
 
 # Now in container
-chmod +x repro/run_all_exprs_and_plot.sh
+chmod +x repro/run_all_exprs_and_plot.sh [--llama32]
 ./repro/run_all_exprs_and_plot.sh
 
 # Before exiting the container, in another shell, copy results out of container
@@ -156,6 +156,13 @@ ssh -4 <HOST> "tar -c -C /home/<USER> /path/to/plots | xz -c" | xz -d | tar -x
 
 ```
 
+## Llama-3.2 Experiments
+
+The reproducer must first obtain a copy of the model by:
+1. Requesting model access from [huggingface](https://huggingface.co/meta-llama/Llama-3.2-3B)
+2. Running 'llama model download --source huggingface --model-id meta-llama/Llama-3.2-3B' to download a checkpoint into ~/.llama/checkpoints
+3. Then follow the previous section, passing --llama32 to the bash scripts invoked.
+
 ## Working with LaunchLib
 
 We developed a tiny library for parallelizing experiments across gpus.
````
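The new Llama-3.2 section assumes the downloaded checkpoint ends up under `~/.llama/checkpoints`, which `--llama32` then mounts into the container. As a small, optional pre-flight check before launching the container; the exact checkpoint subdirectory name here is an assumption based on the model id above:

```python
from pathlib import Path

# Hypothetical sanity check: confirm the Llama-3.2-3B checkpoint exists
# (step 2 above) before running build_run_container.sh --llama32.
ckpt_dir = Path.home() / ".llama" / "checkpoints" / "Llama-3.2-3B"  # assumed layout
if not ckpt_dir.is_dir():
    raise SystemExit(
        f"No checkpoint at {ckpt_dir}; run 'llama model download "
        "--source huggingface --model-id meta-llama/Llama-3.2-3B' first."
    )
print(f"Found checkpoint at {ckpt_dir}")
```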
```diff
@@ -249,6 +256,7 @@ changed, and thus, some results have changed, often for the better.
 We have attempted to disable certain optimizations, where needed, in order to more closely
 match the original results.
 
+
 ### Section 7.2 - GPT-2 Decoding
 
 **Figure 9 - Mean Time per Token with Causal Attention**
@@ -289,6 +297,40 @@ the best tile size for batch size of 64 has shifted to 1024 (instead of 512).
 
 Results match up exactly with original submission. Tempo's circular tensor store uses a single static allocation for windowed attention. Causal attention is decomposed into blocks which are allocated as needed at runtime, causing the step-like behaviour observed.
 
+### Section 7.2 - Llama3.2-3B Decoding
+
+These experiments were not present in the original submission, but have been added to the final version of the paper.
+
+
+**Figure 17a - Causal attention at batch size 16**
+
+![Mean time between tokens with causal attention with batch size 16](expected_results/plots/llama32/tpt/causal_16.png)
+
+
+**Figure 17b - Causal attention at batch size 4**
+
+![Mean time between tokens with causal attention with batch size 4](expected_results/plots/llama32/tpt/causal_4.png)
+
+**Figure 17c - Window attention at batch size 16**
+
+![Mean time between tokens with window attention with batch size 16](expected_results/plots/llama32/tpt/window_16.png)
+
+**For completeness - Window attention at batch size 4**
+
+![Mean time between tokens with window attention with batch size 4](expected_results/plots/llama32/tpt/window_4.png)
+
+**Figure 18 - Block size microbenchmark**
+
+![Block-size microbenchmark](expected_results/plots/llama32/block_size/block.png)
+
+**Figure 19 - Runtime Memory consumption**
+
+![Runtime memory](expected_results/plots/llama32/mem/runtime_mem.png)
+
+**Figure 24 - Compilation time scaling**
+
+![Runtime memory](expected_results/plots/llama32/compilation/compilation_breakdown_multiple.png)
+
 ### Section 7.3 - RL Training (PPO)
 
 **Figure 13 - Small to Medium Scale PPO**
```

repro/build_run_container.sh

Lines changed: 35 additions & 10 deletions

```diff
@@ -1,23 +1,48 @@
 #! /bin/bash
 
+# Default to not mounting llama volume
+MOUNT_LLAMA=false
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --pull)
+            if ! git diff-index --quiet HEAD --; then
+                echo "Stopping due to uncommitted changes which would prevent pulling. Please commit or stash your changes before running this script."
+                exit 1
+            fi
+            git pull
+            shift
+            ;;
+        --llama32)
+            MOUNT_LLAMA=true
+            shift
+            ;;
+        *)
+            echo "Unknown option $1"
+            echo "Usage: $0 [--pull] [--llama32]"
+            echo "  --pull: Pull latest changes from git before building"
+            echo "  --llama32: Mount ~/.llama volume for llama32 experiments"
+            exit 1
+            ;;
+    esac
+done
+
 # Make sure we are in the repo root
 git_repo_root=$(git rev-parse --show-toplevel)
 pushd $git_repo_root
 
 # Trap to ensure popd is called on exit
 trap 'popd' EXIT
 
-# Check for uncommitted changes and pull if --pull is passed
-if [[ "$1" == "--pull" ]]; then
-    if ! git diff-index --quiet HEAD --; then
-        echo "Stopping due to uncommitted changes which would prevent pulling. Please commit or stash your changes before running this script."
-        exit 1
-    fi
-    git pull
-fi
-
 # Build the container
 DOCKER_BUILDKIT=1 docker build -f docker/gpu.dockerfile -t tempo-gpu .
 
 # Run the container
-docker run --name tempo-repro --gpus 'all' --ipc=host --ulimit memlock=-1:-1 -it --rm tempo-gpu bash
+if [ "$MOUNT_LLAMA" = true ]; then
+    echo "Mounting ~/.llama volume for llama32 experiments..."
+    docker run --name tempo-repro -v ~/.llama:/home/tempo/.llama --gpus 'all' --ipc=host --ulimit memlock=-1:-1 -it --rm tempo-gpu bash
+else
+    echo "Running container without llama volume mount..."
+    docker run --name tempo-repro --gpus 'all' --ipc=host --ulimit memlock=-1:-1 -it --rm tempo-gpu bash
+fi
```

repro/data_loading.py

Lines changed: 32 additions & 23 deletions

```diff
@@ -1,5 +1,6 @@
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any
 
 import pandas as pd
 
@@ -15,7 +16,7 @@
 """
 
 
-def read_csv(path: str) -> Union[pd.DataFrame, None]:
+def read_csv(path: str) -> pd.DataFrame | None:
     try:
         return pd.read_csv(path)
     except Exception:
@@ -28,7 +29,7 @@ def parse_error_file(error_file_path: str) -> str:
     Returns "OOM" if the error contains memory-related keywords, "MISSING" otherwise.
     """
     try:
-        with open(error_file_path, "r") as f:
+        with open(error_file_path) as f:
             error_content = f.read().lower()
 
             # Keywords that indicate out-of-memory errors
@@ -57,7 +58,7 @@ def parse_error_file(error_file_path: str) -> str:
         return "MISSING"
 
 
-def get_gpu_id_from_run_data(run_data: Dict[str, Any]) -> int:
+def get_gpu_id_from_run_data(run_data: dict[str, Any]) -> int:
     if run_data["monitor"] is not None:
         gpu_mem_col = [col for col in run_data["monitor"].columns if "gpu" in col and "mem" in col]
         if gpu_mem_col:
@@ -66,8 +67,8 @@ def get_gpu_id_from_run_data(run_data: Dict[str, Any]) -> int:
 
 
 def get_single_run_data(
-    path: str, params: Dict[str, Any], name_function: Callable[[str, Dict[str, Any]], str]
-) -> Dict[str, Any]:
+    path: str, params: dict[str, Any], name_function: Callable[[str, dict[str, Any]], str]
+) -> dict[str, Any]:
     # Generate expected experiment name
     expected_name, experiment_path = name_function(path, params)
 
@@ -102,12 +103,12 @@ def get_single_run_data(
 
 def load_sweep_data(
     base_path: str,
-    base_params: Dict[str, Any],
-    sweeps: Dict[str, List[Any]],
-    systems: List[str],
-    name_function: Callable[[Dict[str, Any]], str],
+    base_params: dict[str, Any],
+    sweeps: dict[str, list[Any]],
+    systems: list[str],
+    name_function: Callable[[dict[str, Any]], str],
     caching_allocators: bool = True,
-) -> Dict[str, Dict[Any, Dict[str, Dict[str, Any]]]]:
+) -> dict[str, dict[Any, dict[str, dict[str, Any]]]]:
     """Load data from the experiment results using the naming scheme from shared.py"""
     # Access the small_to_med_scale subpath
     path = Path(base_path)
@@ -138,10 +139,10 @@
 
 
 def get_sweep_df(
-    data: Dict[str, Dict[Any, Dict[str, Dict[str, Any]]]],
-    sweeps: Dict[str, List[Any]],
+    data: dict[str, dict[Any, dict[str, dict[str, Any]]]],
+    sweeps: dict[str, list[Any]],
     sweep_key: str,
-    systems: List[str],
+    systems: list[str],
 ) -> pd.DataFrame:
     data_list = []
     for sweep_value in sweeps[sweep_key]:
@@ -155,7 +156,7 @@
 
 
 def has_error(
-    data: Dict[str, Dict[Any, Dict[str, Dict[str, Any]]]],
+    data: dict[str, dict[Any, dict[str, dict[str, Any]]]],
     framework: str,
     sweep_key: str,
     sweep_value,
@@ -174,7 +175,7 @@
 
 
 def get_error_type(
-    data: Dict[str, Dict[Any, Dict[str, Dict[str, Any]]]],
+    data: dict[str, dict[Any, dict[str, dict[str, Any]]]],
     framework: str,
     sweep_key: str,
     sweep_value,
@@ -184,13 +185,13 @@
 
 
 def get_normalized_dfs(
-    data: Dict[str, Dict[Any, Dict[str, Dict[str, Any]]]],
+    data: dict[str, dict[Any, dict[str, dict[str, Any]]]],
     framework: str,
     sweep_key: str,
    sweep_value: Any,
     iterations_from_start_to_remove: int = 1,
     iterations_from_end_to_remove: int = 1,
-) -> Tuple[pd.DataFrame, pd.DataFrame]:
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Get normalized dataframes for a specific framework and sweep value"""
     run = data[sweep_key][sweep_value][framework]
     df_monitor = run["monitor"]
@@ -279,14 +280,22 @@ def compute_ratios(df: pd.DataFrame) -> pd.DataFrame:
 
 
 def build_aggregate_metric_df(
-    data: Dict[str, Dict[Any, Dict[str, Dict[str, Any]]]],
+    data: dict[str, dict[Any, dict[str, dict[str, Any]]]],
     sweep_key: str,
     sweep_value: Any,
     sys: str,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     error = has_error(data, sys, sweep_key, sweep_value)
     error_type = get_error_type(data, sys, sweep_key, sweep_value)
 
+    if not error:
+        df_monitor, df_log = get_normalized_dfs(data, sys, sweep_key, sweep_value, 0)
+
+    mem_dict = {
+        "gpu_mem_mean": df_monitor["gpu_mem_util"].mean() if not error else 0,
+        "gpu_mem_median": df_monitor["gpu_mem_util"].median() if not error else 0,
+        "gpu_mem_peak": df_monitor["gpu_mem_util"].max() if not error else 0,
+    }
     if not error:
         df_monitor, df_log = get_normalized_dfs(data, sys, sweep_key, sweep_value)
 
@@ -297,9 +306,9 @@
         "framework": sys,
         "iter_mean": df_log["elapsed_sec"].diff().mean() if not error else 0,
         "iter_std": df_log["elapsed_sec"].diff().std() if not error else 0,
-        "gpu_mem_mean": df_monitor["gpu_mem_util"].mean() if not error else 0,
-        "gpu_mem_median": df_monitor["gpu_mem_util"].median() if not error else 0,
-        "gpu_mem_peak": df_monitor["gpu_mem_util"].max() if not error else 0,
+        "gpu_mem_mean": mem_dict["gpu_mem_mean"],
+        "gpu_mem_median": mem_dict["gpu_mem_median"],
+        "gpu_mem_peak": mem_dict["gpu_mem_peak"],
         "gpu_util_mean": df_monitor["gpu_util"].mean() if not error else 0,
         "gpu_util_median": df_monitor["gpu_util"].median() if not error else 0,
         "gpu_util_peak": df_monitor["gpu_util"].max() if not error else 0,
```
[5 binary image files changed (previews omitted): 70 KB, 86.5 KB, 66.8 KB, 44.8 KB, 43.8 KB]
