12
12
13
13
import numpy as np
14
14
import os
15
+ from PIL import Image
15
16
from python_coreml_stable_diffusion .torch2coreml import compute_psnr , get_pipeline
16
17
import time
17
18
30
31
31
32
# Signal integrity is computed based on these 8 random prompts.
# run_pipe() passes this list as `prompt` (with one empty negative prompt per
# entry). For Img2Img pipelines it is passed together with
# RANDOM_TEST_IMAGE_DATA, so keep both lists the same length
# (NOTE(review): entries are presumably paired index-by-index — confirm).
RANDOM_TEST_DATA = [
    "a black and brown dog standing outside a door.",
    "a person on a motorcycle makes a turn on the track.",
    "inflatable boats sit on the arizona river, and on the bank",
    "a white cat sitting under a white umbrella",
    "black bear standing in a field of grass under a tree.",
    "a train that is parked on tracks and has graffiti writing on it, with a mountain range in the background.",
    "a cake inside of a pan sitting in an oven.",
    "a table with paper plates and flowers in a home",
]
38
43
44
# Side length (in pixels) of the square test images; run_pipe() also uses it
# as the `height`/`width` of generated outputs.
TEST_RESOLUTION = 768


def _load_test_image(url, timeout=30):
    """Download ``url`` and return it as an RGB image of the test resolution.

    Args:
        url: HTTP(S) location of the source image.
        timeout: Seconds to wait for the server before giving up, so that a
            stalled connection cannot hang the module import forever.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, resized with Lanczos resampling to
        ``(TEST_RESOLUTION, TEST_RESOLUTION)``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    from io import BytesIO

    # The context manager releases the connection even if decoding fails.
    with requests.get(url, timeout=timeout) as response:
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # HTML error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image.convert("RGB").resize(
            (TEST_RESOLUTION, TEST_RESOLUTION), Image.LANCZOS
        )


# Reference input images handed to Img2Img pipelines by run_pipe().
# NOTE: these are downloaded eagerly when the module is imported.
RANDOM_TEST_IMAGE_DATA = [
    _load_test_image(path) for path in [
        "http://farm1.staticflickr.com/106/298138827_19bb723252_z.jpg",
        "http://farm4.staticflickr.com/3772/9666116202_648cd752d6_z.jpg",
        "http://farm3.staticflickr.com/2238/2472574092_f5534bb2f7_z.jpg",
        "http://farm1.staticflickr.com/220/475442674_47d81fdc2c_z.jpg",
        "http://farm8.staticflickr.com/7231/7359341784_4c5358197f_z.jpg",
        "http://farm8.staticflickr.com/7283/8737653089_d0c77b8597_z.jpg",
        "http://farm3.staticflickr.com/2454/3989339438_2f32b76ebb_z.jpg",
        "http://farm1.staticflickr.com/34/123005230_13051344b1_z.jpg",
    ]
]
60
+
61
+
39
62
# Copied from https://github.com/apple/coremltools/blob/7.0b1/coremltools/optimize/coreml/_quantization_passes.py#L602
40
63
from coremltools .converters .mil .mil import types
41
64
def fake_linear_quantize (val , axis = - 1 , mode = 'LINEAR' , dtype = types .int8 ):
@@ -217,14 +240,11 @@ def fake_palette_from_recipe(module, recipe):
217
240
218
241
logger .info (f"Palettized to { tot_bits / tot_numel :.2f} -bits mixed palette ({ tot_bits / 8e6 } MB) " )
219
242
220
-
221
- TEST_RESOLUTION = 768
222
-
223
243
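# Globally synced RNG state: run_pipe() resets `rng` to `rng_state` before every call so all runs start from the same generator state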
224
244
rng = torch .Generator ()
225
245
rng_state = rng .get_state ()
226
246
227
- def run_pipe (pipe , prompts ):
247
+ def run_pipe (pipe ):
228
248
if torch .backends .mps .is_available ():
229
249
device = "mps"
230
250
elif torch .cuda .is_available ():
@@ -235,18 +255,29 @@ def run_pipe(pipe, prompts):
235
255
236
256
global rng , rng_state
237
257
rng .set_state (rng_state )
238
- return np .array ([latent .cpu ().numpy () for latent in pipe .to (device )(
239
- prompt = prompts ,
258
+ kwargs = dict (
259
+ prompt = RANDOM_TEST_DATA ,
260
+ negative_prompt = ["" ] * len (RANDOM_TEST_DATA ),
240
261
num_inference_steps = 1 ,
241
262
height = TEST_RESOLUTION ,
242
263
width = TEST_RESOLUTION ,
243
264
output_type = "latent" ,
244
- generator = rng ).images ])
265
+ generator = rng
266
+ )
267
+ if "Img2Img" in pipe .__class__ .__name__ :
268
+ kwargs ["image" ] = RANDOM_TEST_IMAGE_DATA
269
+ kwargs .pop ("height" )
270
+ kwargs .pop ("width" )
271
+
272
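+ # Run a single denoising step: img2img is expected to execute about
+ # int(num_inference_steps * strength) = int(4 * 0.25) = 1 step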
273
+ kwargs ["num_inference_steps" ] = 4
274
+ kwargs ["strength" ] = 0.25
275
+
276
+ return np .array ([latent .cpu ().numpy () for latent in pipe .to (device )(** kwargs ).images ])
245
277
246
278
247
279
def benchmark_signal_integrity (pipe ,
248
280
candidates ,
249
- sample_input_batch ,
250
281
nbits ,
251
282
cumulative ,
252
283
in_ngroups = 1 ,
@@ -263,7 +294,7 @@ def benchmark_signal_integrity(pipe,
263
294
264
295
# If reference outputs are not provided, treat current pipe as reference
265
296
if ref_out is None :
266
- ref_out = run_pipe (pipe , sample_input_batch )
297
+ ref_out = run_pipe (pipe )
267
298
268
299
for candidate in tqdm (candidates ):
269
300
palettized = False
@@ -280,7 +311,7 @@ def benchmark_signal_integrity(pipe,
280
311
if not palettized :
281
312
raise KeyError (name )
282
313
283
- test_out = run_pipe (pipe , sample_input_batch )
314
+ test_out = run_pipe (pipe )
284
315
285
316
if not cumulative :
286
317
restore_weight (module , orig_weight )
@@ -304,11 +335,11 @@ def descending_psnr_order(results):
304
335
def simulate_quant_fn (ref_pipe , quantization_to_simulate ):
305
336
simulated_pipe = deepcopy (ref_pipe .to ('cpu' ))
306
337
quantization_to_simulate (simulated_pipe .unet )
307
- simulated_out = run_pipe (simulated_pipe , RANDOM_TEST_DATA )
338
+ simulated_out = run_pipe (simulated_pipe )
308
339
del simulated_pipe
309
340
gc .collect ()
310
341
311
- ref_out = run_pipe (ref_pipe , RANDOM_TEST_DATA )
342
+ ref_out = run_pipe (ref_pipe )
312
343
simulated_psnr = sum ([
313
344
float (f"{ compute_psnr (r ,t ):.1f} " )
314
345
for r ,t in zip (ref_out , simulated_out )
@@ -414,7 +445,7 @@ def main(args):
414
445
logger .info ("Done." )
415
446
416
447
# Cache reference outputs
417
- ref_out = run_pipe (pipe , RANDOM_TEST_DATA )
448
+ ref_out = run_pipe (pipe )
418
449
419
450
# Bookkeeping
420
451
os .makedirs (args .o , exist_ok = True )
@@ -442,7 +473,6 @@ def main(args):
442
473
results ['single_layer' ][str (nbits )] = benchmark_signal_integrity (
443
474
pipe ,
444
475
candidates ,
445
- RANDOM_TEST_DATA ,
446
476
nbits ,
447
477
cumulative = False ,
448
478
ref_out = ref_out ,
@@ -457,7 +487,6 @@ def main(args):
457
487
results ['cumulative' ][str (nbits )] = benchmark_signal_integrity (
458
488
deepcopy (pipe ),
459
489
sorted_candidates ,
460
- RANDOM_TEST_DATA ,
461
490
nbits ,
462
491
cumulative = True ,
463
492
ref_out = ref_out ,
0 commit comments