Skip to content

Commit 5298e02

Browse files
Add experiments data location to run_experiment (#46)
Co-authored-by: Joel Schlosser <[email protected]>
1 parent a60de07 commit 5298e02

File tree

2 files changed

+48
-41
lines changed

2 files changed

+48
-41
lines changed

experiments/run_experiments.py

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import subprocess
2-
import os
3-
import math
2+
import fire
43
import itertools
4+
import functools
55

66
home = "/home/cpuhrsch"
77

@@ -30,22 +30,8 @@ def change_sam_commit(commit_name):
3030
assert result.returncode == 0
3131

3232

33-
root_cmd = ["python", "eval_combo.py",
34-
"--coco_root_dir",
35-
"experiments_data/datasets/coco2017",
36-
"--coco_slice_name",
37-
"val2017",
38-
"--sam_checkpoint_base_path",
39-
"experiments_data/checkpoints",
40-
"--sam_model_type",
41-
"vit_b",
42-
"--point_sampling_cache_dir",
43-
"experiments_data/tmp/sam_coco_mask_center_cache",
44-
"--mask_debug_out_dir",
45-
"experiments_data/tmp/sam_eval_masks_out"]
46-
47-
48-
def run_experiment(idx,
33+
def run_experiment(experiments_data,
34+
idx,
4935
sam_commit_name,
5036
model_type,
5137
batch_size,
@@ -61,6 +47,19 @@ def run_experiment(idx,
6147
profile_path=None,
6248
profile_top=False,
6349
memory_path=None):
50+
root_cmd = ["python", "eval_combo.py",
51+
"--coco_root_dir",
52+
f"{experiments_data}/datasets/coco2017",
53+
"--coco_slice_name",
54+
"val2017",
55+
"--sam_checkpoint_base_path",
56+
f"{experiments_data}/checkpoints",
57+
"--sam_model_type",
58+
"vit_b",
59+
"--point_sampling_cache_dir",
60+
f"{experiments_data}/tmp/sam_coco_mask_center_cache",
61+
"--mask_debug_out_dir",
62+
f"{experiments_data}/tmp/sam_eval_masks_out"]
6463
args = root_cmd
6564
args = args + ["--sam_model_type", model_type]
6665
args = args + ["--batch_size", str(batch_size)]
@@ -139,24 +138,32 @@ def run_traces(*args, **kwargs):
139138
result = subprocess.run(conversion_cmd, capture_output=True)
140139
assert result.returncode == 0
141140

142-
# run_traces("fp32", "default", "vit_b", 16, 32, print_header=True)
143-
# run_traces("fp16", "codesign", "vit_b", 16, 32, use_half=True)
144-
# run_traces("compile", "codesign", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
145-
# run_traces("SDPA", "sdpa-decoder", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
146-
# run_traces("Triton", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
147-
# run_traces("NT", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True)
148-
# run_traces("int8", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
149-
# run_traces("sparse", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
150-
151-
print_header = True
152-
for bs, model in itertools.product([1, 32], ["vit_b", "vit_h"]):
153-
# run_experiment("fp32", "default", model, bs, 32, print_header=print_header)
154-
# print_header = False
155-
# run_experiment("bf16", "codesign", model, bs, 32, use_half="bfloat16")
156-
# run_experiment("compile", "codesign", model, bs, 32, use_half="bfloat16", use_compile="max-autotune")
157-
# run_experiment("SDPA", "sdpa-decoder", model, bs, 32, use_half="bfloat16", use_compile="max-autotune")
158-
run_experiment("Triton", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune")
159-
if bs > 1:
160-
run_experiment("NT", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1))
161-
run_experiment("int8", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1), compress="dynamic_quant")
162-
run_experiment("sparse", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1), compress="sparse")
141+
def run(experiments_data=None):
142+
if experiments_data is None:
143+
experiments_data = "experiments_data"
144+
145+
# run_traces("fp32", "default", "vit_b", 16, 32, print_header=True)
146+
# run_traces("fp16", "codesign", "vit_b", 16, 32, use_half=True)
147+
# run_traces("compile", "codesign", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
148+
# run_traces("SDPA", "sdpa-decoder", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
149+
# run_traces("Triton", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune")
150+
# run_traces("NT", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True)
151+
# run_traces("int8", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
152+
# run_traces("sparse", "local-fork", "vit_b", 16, 32, use_half=True, use_compile="max-autotune", use_nested_tensor=True, compress="sparse")
153+
154+
rexp = functools.partial(run_experiment, experiments_data)
155+
print_header = True
156+
for bs, model in itertools.product([1, 32], ["vit_b", "vit_h"]):
157+
# rexp("fp32", "default", model, bs, 32, print_header=print_header)
158+
print_header = False
159+
# rexp("bf16", "codesign", model, bs, 32, use_half="bfloat16")
160+
# rexp("compile", "codesign", model, bs, 32, use_half="bfloat16", use_compile="max-autotune")
161+
# rexp("SDPA", "sdpa-decoder", model, bs, 32, use_half="bfloat16", use_compile="max-autotune")
162+
rexp("Triton", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", capture_output=False)
163+
if bs > 1:
164+
rexp("NT", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1))
165+
rexp("int8", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1), compress="dynamic_quant")
166+
rexp("sparse", "local-fork", model, bs, 32, use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(bs > 1), compress="sparse")
167+
168+
if __name__ == '__main__':
169+
fire.Fire(run)

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
version='0.2',
88
packages=packages,
99
install_requires=[
10-
'torch>=2.2.0.dev20231015',
11-
'torchvision>=0.17.0.dev20231015',
10+
'torch>=2.2.0.dev20231019',
11+
'torchvision>=0.17.0.dev20231019',
1212
'diskcache',
1313
'pycocotools',
1414
'scipy',

0 commit comments

Comments
 (0)