Skip to content

Commit 8c8ad93

Browse files
Merge pull request #4635 from mezotaken/master
CI tests with github-actions and some improvements to testing
2 parents b24aed0 + 14dfede commit 8c8ad93

File tree

11 files changed

+125
-44
lines changed

11 files changed

+125
-44
lines changed

.github/workflows/run_tests.yaml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
name: Run basic features tests on CPU with empty SD model
2+
3+
on:
4+
- push
5+
- pull_request
6+
7+
jobs:
8+
test:
9+
runs-on: ubuntu-latest
10+
steps:
11+
- name: Checkout Code
12+
uses: actions/checkout@v3
13+
- name: Set up Python 3.10
14+
uses: actions/setup-python@v4
15+
with:
16+
python-version: 3.10.6
17+
- uses: actions/cache@v3
18+
with:
19+
path: ~/.cache/pip
20+
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
21+
restore-keys: ${{ runner.os }}-pip-
22+
- name: Run tests
23+
run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
24+
- name: Upload main app stdout-stderr
25+
uses: actions/upload-artifact@v3
26+
if: always()
27+
with:
28+
name: stdout-stderr
29+
path: |
30+
test/stdout.txt
31+
test/stderr.txt

launch.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,19 @@ def extract_arg(args, name):
1717
return [x for x in args if x != name], name in args
1818

1919

20+
def extract_opt(args, name):
21+
opt = None
22+
is_present = False
23+
if name in args:
24+
is_present = True
25+
idx = args.index(name)
26+
del args[idx]
27+
if idx < len(args) and args[idx][0] != "-":
28+
opt = args[idx]
29+
del args[idx]
30+
return args, is_present, opt
31+
32+
2033
def run(command, desc=None, errdesc=None, custom_env=None):
2134
if desc is not None:
2235
print(desc)
@@ -151,12 +164,11 @@ def prepare_enviroment():
151164
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
152165

153166
sys.argv += shlex.split(commandline_args)
154-
test_argv = [x for x in sys.argv if x != '--tests']
155167

156168
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
157169
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
158170
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
159-
sys.argv, run_tests = extract_arg(sys.argv, '--tests')
171+
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
160172
xformers = '--xformers' in sys.argv
161173
ngrok = '--ngrok' in sys.argv
162174

@@ -221,24 +233,30 @@ def prepare_enviroment():
221233
exit(0)
222234

223235
if run_tests:
224-
tests(test_argv)
225-
exit(0)
236+
exitcode = tests(test_dir)
237+
exit(exitcode)
226238

227239

228-
def tests(argv):
229-
if "--api" not in argv:
230-
argv.append("--api")
240+
def tests(test_dir):
241+
if "--api" not in sys.argv:
242+
sys.argv.append("--api")
243+
if "--ckpt" not in sys.argv:
244+
sys.argv.append("--ckpt")
245+
sys.argv.append("./test/test_files/empty.pt")
246+
if "--skip-torch-cuda-test" not in sys.argv:
247+
sys.argv.append("--skip-torch-cuda-test")
231248

232-
print(f"Launching Web UI in another process for testing with arguments: {' '.join(argv[1:])}")
249+
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
233250

234251
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
235-
proc = subprocess.Popen([sys.executable, *argv], stdout=stdout, stderr=stderr)
252+
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
236253

237254
import test.server_poll
238-
test.server_poll.run_tests()
255+
exitcode = test.server_poll.run_tests(proc, test_dir)
239256

240257
print(f"Stopping Web UI process with id {proc.pid}")
241258
proc.kill()
259+
return exitcode
242260

243261

244262
def start():

test/advanced_features/__init__.py

Whitespace-only changes.

test/extras_test.py renamed to test/advanced_features/extras_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ def setUp(self):
1111
"codeformer_visibility": 0,
1212
"codeformer_weight": 0,
1313
"upscaling_resize": 2,
14-
"upscaling_resize_w": 512,
15-
"upscaling_resize_h": 512,
14+
"upscaling_resize_w": 128,
15+
"upscaling_resize_h": 128,
1616
"upscaling_crop": True,
1717
"upscaler_1": "None",
1818
"upscaler_2": "None",
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import unittest
2+
import requests
3+
4+
5+
class TestTxt2ImgWorking(unittest.TestCase):
6+
def setUp(self):
7+
self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
8+
self.simple_txt2img = {
9+
"enable_hr": False,
10+
"denoising_strength": 0,
11+
"firstphase_width": 0,
12+
"firstphase_height": 0,
13+
"prompt": "example prompt",
14+
"styles": [],
15+
"seed": -1,
16+
"subseed": -1,
17+
"subseed_strength": 0,
18+
"seed_resize_from_h": -1,
19+
"seed_resize_from_w": -1,
20+
"batch_size": 1,
21+
"n_iter": 1,
22+
"steps": 3,
23+
"cfg_scale": 7,
24+
"width": 64,
25+
"height": 64,
26+
"restore_faces": False,
27+
"tiling": False,
28+
"negative_prompt": "",
29+
"eta": 0,
30+
"s_churn": 0,
31+
"s_tmax": 0,
32+
"s_tmin": 0,
33+
"s_noise": 1,
34+
"sampler_index": "Euler a"
35+
}
36+
37+
def test_txt2img_with_restore_faces_performed(self):
38+
self.simple_txt2img["restore_faces"] = True
39+
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
40+
41+
42+
class TestTxt2ImgCorrectness(unittest.TestCase):
43+
pass
44+
45+
46+
if __name__ == "__main__":
47+
unittest.main()

test/basic_features/__init__.py

Whitespace-only changes.

test/img2img_test.py renamed to test/basic_features/img2img_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,5 @@ def test_inpainting_masked_performed(self):
5151
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
5252

5353

54-
class TestImg2ImgCorrectness(unittest.TestCase):
55-
pass
56-
57-
5854
if __name__ == "__main__":
5955
unittest.main()

test/txt2img_test.py renamed to test/basic_features/txt2img_test.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,26 +49,20 @@ def test_txt2img_with_hrfix_performed(self):
4949
self.simple_txt2img["enable_hr"] = True
5050
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
5151

52-
def test_txt2img_with_restore_faces_performed(self):
53-
self.simple_txt2img["restore_faces"] = True
54-
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
55-
56-
def test_txt2img_with_tiling_faces_performed(self):
52+
def test_txt2img_with_tiling_performed(self):
5753
self.simple_txt2img["tiling"] = True
5854
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
5955

6056
def test_txt2img_with_vanilla_sampler_performed(self):
6157
self.simple_txt2img["sampler_index"] = "PLMS"
6258
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
59+
self.simple_txt2img["sampler_index"] = "DDIM"
60+
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
6361

6462
def test_txt2img_multiple_batches_performed(self):
6563
self.simple_txt2img["n_iter"] = 2
6664
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
6765

6866

69-
class TestTxt2ImgCorrectness(unittest.TestCase):
70-
pass
71-
72-
7367
if __name__ == "__main__":
7468
unittest.main()

test/utils_test.py renamed to test/basic_features/utils_test.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,6 @@ def setUp(self):
1818
def test_options_get(self):
1919
self.assertEqual(requests.get(self.url_options).status_code, 200)
2020

21-
def test_options_write(self):
22-
response = requests.get(self.url_options)
23-
self.assertEqual(response.status_code, 200)
24-
25-
pre_value = response.json()["send_seed"]
26-
27-
self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
28-
29-
response = requests.get(self.url_options)
30-
self.assertEqual(response.status_code, 200)
31-
self.assertEqual(response.json()["send_seed"], not pre_value)
32-
33-
requests.post(self.url_options, json={"send_seed": pre_value})
34-
3521
def test_cmd_flags(self):
3622
self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
3723

@@ -60,4 +46,8 @@ def test_artist_categories(self):
6046
self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)
6147

6248
def test_artists(self):
63-
self.assertEqual(requests.get(self.url_artists).status_code, 200)
49+
self.assertEqual(requests.get(self.url_artists).status_code, 200)
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

test/server_poll.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,22 @@
33
import time
44

55

6-
def run_tests():
6+
def run_tests(proc, test_dir):
77
timeout_threshold = 240
88
start_time = time.time()
99
while time.time()-start_time < timeout_threshold:
1010
try:
1111
requests.head("http://localhost:7860/")
1212
break
1313
except requests.exceptions.ConnectionError:
14-
pass
15-
if time.time()-start_time < timeout_threshold:
16-
suite = unittest.TestLoader().discover('', pattern='*_test.py')
14+
if proc.poll() is not None:
15+
break
16+
if proc.poll() is None:
17+
if test_dir is None:
18+
test_dir = ""
19+
suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
1720
result = unittest.TextTestRunner(verbosity=2).run(suite)
21+
return len(result.failures) + len(result.errors)
1822
else:
1923
print("Launch unsuccessful")
24+
return 1

0 commit comments

Comments
 (0)