-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathbenchmark.py
More file actions
108 lines (83 loc) · 3.08 KB
/
benchmark.py
File metadata and controls
108 lines (83 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import pytest
import torch
from world_engine import WorldEngine
# TODO
# - benchmark encode img
# - benchmark encode prompt
def version_with_commit(pkg):
    """Return the installed version string for *pkg*, with a short commit id if known.

    The commit id is read from the distribution's optional ``direct_url.json``
    metadata (written for VCS/editable installs, PEP 610); when it is missing
    or unparseable, the plain version is returned.
    """
    import json
    from importlib.metadata import distribution

    dist = distribution(pkg.__name__.split(".")[0])
    base_version = dist.version
    commit = None
    try:
        raw = dist.read_text("direct_url.json")
        if raw:
            commit = json.loads(raw).get("vcs_info", {}).get("commit_id")
    except (FileNotFoundError, json.JSONDecodeError, TypeError):
        commit = None
    if commit:
        return f"{base_version} @ {commit[:7]}"
    return base_version
@pytest.fixture(scope="session", autouse=True)
def print_env_info():
    """Print software versions and hardware details once per benchmark session."""
    import platform

    import world_engine as world_engine_pkg

    report = [
        "\n=== Environment ===",
        f"torch: {torch.__version__}",
        f"torch.cuda: {torch.version.cuda}",
        f"world_engine: {version_with_commit(world_engine_pkg)}",
        "",
        "=== Hardware ===",
        f"OS: {platform.system()} {platform.release()} ({platform.machine()})",
        f"CPU: {platform.processor() or 'unknown'}",
    ]
    print("\n".join(report))
    if torch.cuda.is_available():
        device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(device_index)
        gpu_report = [
            f"GPU: {props.name}",
            f"    capability {props.major}.{props.minor}",
            f"    total memory: {props.total_memory / 1e9:.1f} GB",
        ]
        print("\n".join(gpu_report))
    else:
        print("GPU: none (CUDA not available)")
def get_warm_engine(model_uri, model_overrides=None):
    """Construct a CUDA ``WorldEngine`` for *model_uri* and warm it up.

    *model_overrides* (optional dict) is layered on top of the default VAE
    override before construction.
    """
    overrides = {"ae_uri": "OpenWorldLabs/owl_vae_f16_c16_distill_v0_nogan"}
    if model_overrides:
        overrides.update(model_overrides)
    # NOTE(review): load_weights=False — presumably timing architecture speed
    # with random weights; confirm this is intended for all benchmarks.
    engine = WorldEngine(
        model_uri,
        model_config_overrides=overrides,
        device="cuda",
        load_weights=False
    )
    # Generate a few frames so CUDA kernels/allocator caches are initialized
    # before any timed measurement.
    for _ in range(3):
        engine.gen_frame()
    return engine
# Default model benchmarked by the session-wide `engine` fixture.
DEFAULT_MODEL_URI = "Overworld/Waypoint-1-Small"


@pytest.fixture(scope="session")
def engine():
    """Session-scoped, pre-warmed engine for the default benchmark model.

    The previous signature declared ``model_uri`` as a default-valued
    parameter; pytest ignores defaulted parameters when resolving fixture
    arguments, so it could never be overridden and only suggested a
    configurability that did not exist. Use a module constant instead.
    """
    return get_warm_engine(DEFAULT_MODEL_URI)
@pytest.fixture(scope="session")
def last_latent(engine):
    """A detached latent frame from the warm engine, shared by decoder benchmarks."""
    latent = engine.gen_frame(return_img=False)
    return latent.detach()
def test_img_decoder_only(benchmark, engine, last_latent):
    """Benchmark VAE decoding of a single latent, excluding generation cost."""
    def decode_once():
        with torch.amp.autocast("cuda", torch.bfloat16):
            engine.vae.decode(last_latent)
        # Block until the GPU finishes so the timer measures real decode work,
        # not just kernel launch latency.
        torch.cuda.synchronize()

    benchmark(decode_once)
# Per-parametrization model config overrides; None means the stock config.
MODEL_OVERRIDES = [None]


def _overrides_id(overrides):
    # Readable parametrize id like "k=v,k2=v2"; empty string for no overrides.
    if not overrides:
        return ""
    return ",".join(f"{k}={v}" for k, v in overrides.items())


@pytest.mark.parametrize("dit_only", [True])
@pytest.mark.parametrize("n_frames", [256])
@pytest.mark.parametrize("model_overrides", MODEL_OVERRIDES, ids=_overrides_id)
def test_ar_rollout(benchmark, dit_only, n_frames, model_overrides):
    """Benchmark an autoregressive rollout of *n_frames* generated frames.

    With ``dit_only`` the VAE image decode is skipped (``return_img=False``),
    isolating the transformer's per-frame cost.
    """
    engine = get_warm_engine("Overworld/Waypoint-1-Small", model_overrides=model_overrides)

    def setup():
        # Fresh context plus one frame so every round starts from the same state.
        engine.reset()
        engine.gen_frame(return_img=not dit_only)
        torch.cuda.synchronize()

    def target():
        for _ in range(n_frames):
            engine.gen_frame(return_img=not dit_only)
        # Ensure all queued GPU work is included in the measured time.
        torch.cuda.synchronize()

    benchmark.pedantic(target, setup=setup, rounds=20)