Skip to content

Commit f1c1ce0

Browse files
committed
Make VRAM test a pytest file
1 parent 30e2f76 commit f1c1ce0

File tree

2 files changed

+43
-50
lines changed

2 files changed

+43
-50
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import os
2+
from collections import namedtuple
3+
4+
from sdkit import Context
5+
from sdkit.generate import generate_images
6+
from sdkit.models import load_model
7+
from sdkit.utils import get_device_usage, log
8+
9+
10+
def test_vram_frees_after_image():
11+
DeviceUsage = namedtuple(
12+
"DeviceUsage", ["cpu_used", "ram_used", "ram_total", "vram_used", "vram_total", "vram_peak"]
13+
)
14+
15+
c = Context()
16+
17+
log.info("Starting..")
18+
usage_start = DeviceUsage(*get_device_usage(c.device, log_info=True))
19+
20+
c.model_paths["stable-diffusion"] = "models/stable-diffusion/1.x/sd-v1-4.ckpt"
21+
load_model(c, "stable-diffusion")
22+
23+
log.info("Loaded the model..")
24+
usage_model_load = DeviceUsage(*get_device_usage(c.device, log_info=True))
25+
26+
try:
27+
images = generate_images(c, prompt="Photograph of an astronaut riding a horse")
28+
except Exception as e:
29+
log.exception(e)
30+
31+
log.info("Generated the image..")
32+
usage_after_render = DeviceUsage(*get_device_usage(c.device, log_info=True))
33+
34+
print("")
35+
log.info(
36+
f"VRAM trend: {usage_start.vram_used:.1f} (start) GiB to {usage_model_load.vram_used:.1f} GiB (before render) to {usage_after_render.vram_used:.1f} GiB (after render)"
37+
)
38+
print("")
39+
40+
max_expected_vram = usage_model_load.vram_used + 0.3
41+
assert (
42+
usage_after_render.vram_used < max_expected_vram
43+
), f"Test failed! VRAM after render was expected to be below {max_expected_vram:.1f} GiB, but was {usage_after_render.vram_used:.1f} GiB!"

tests/vram_frees_after_image_generation.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

0 commit comments

Comments
 (0)