
Commit e3c1e36

SDXL refiner mixed-bit palettization pre-analysis
1 parent cf4afe1 commit e3c1e36

File tree: 2 files changed, +57 −20 lines

python_coreml_stable_diffusion/mixed_bit_compression_pre_analysis.py

Lines changed: 48 additions & 19 deletions
@@ -12,6 +12,7 @@
 
 import numpy as np
 import os
+from PIL import Image
 from python_coreml_stable_diffusion.torch2coreml import compute_psnr, get_pipeline
 import time
 
@@ -30,12 +31,34 @@
 
 # Signal integrity is computed based on these 4 random prompts
 RANDOM_TEST_DATA = [
-    "a photo of an astronaut riding a horse on mars",
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "a photo of a dog",
-    "a photo of a cat",
+    "a black and brown dog standing outside a door.",
+    "a person on a motorcycle makes a turn on the track.",
+    "inflatable boats sit on the arizona river, and on the bank",
+    "a white cat sitting under a white umbrella",
+    "black bear standing in a field of grass under a tree.",
+    "a train that is parked on tracks and has graffiti writing on it, with a mountain range in the background.",
+    "a cake inside of a pan sitting in an oven.",
+    "a table with paper plates and flowers in a home",
 ]
 
+TEST_RESOLUTION = 768
+
+RANDOM_TEST_IMAGE_DATA = [
+    Image.open(
+        requests.get(path, stream=True).raw).convert("RGB").resize(
+            (TEST_RESOLUTION, TEST_RESOLUTION), Image.LANCZOS
+    ) for path in [
+        "http://farm1.staticflickr.com/106/298138827_19bb723252_z.jpg",
+        "http://farm4.staticflickr.com/3772/9666116202_648cd752d6_z.jpg",
+        "http://farm3.staticflickr.com/2238/2472574092_f5534bb2f7_z.jpg",
+        "http://farm1.staticflickr.com/220/475442674_47d81fdc2c_z.jpg",
+        "http://farm8.staticflickr.com/7231/7359341784_4c5358197f_z.jpg",
+        "http://farm8.staticflickr.com/7283/8737653089_d0c77b8597_z.jpg",
+        "http://farm3.staticflickr.com/2454/3989339438_2f32b76ebb_z.jpg",
+        "http://farm1.staticflickr.com/34/123005230_13051344b1_z.jpg",
+]]
+
+
 # Copied from https://github.com/apple/coremltools/blob/7.0b1/coremltools/optimize/coreml/_quantization_passes.py#L602
 from coremltools.converters.mil.mil import types
 def fake_linear_quantize(val, axis=-1, mode='LINEAR', dtype=types.int8):
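Note: the new RANDOM_TEST_IMAGE_DATA block above fetches eight Flickr-hosted test images over HTTP and resizes them to the test resolution; it relies on requests being imported elsewhere in this module. A minimal standalone sketch of the same preprocessing pattern (the helper name is illustrative, not part of the commit):

import requests
from PIL import Image

TEST_RESOLUTION = 768

def load_test_image(url, resolution=TEST_RESOLUTION):
    # Stream the image over HTTP, decode it with PIL, and resize it to a
    # square test resolution with Lanczos resampling, mirroring the
    # preprocessing applied to RANDOM_TEST_IMAGE_DATA above.
    raw = requests.get(url, stream=True).raw
    return Image.open(raw).convert("RGB").resize(
        (resolution, resolution), Image.LANCZOS)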
@@ -217,14 +240,11 @@ def fake_palette_from_recipe(module, recipe):
 
     logger.info(f"Palettized to {tot_bits/tot_numel:.2f}-bits mixed palette ({tot_bits/8e6} MB) ")
 
-
-TEST_RESOLUTION = 768
-
 # Globally synced RNG state
 rng = torch.Generator()
 rng_state = rng.get_state()
 
-def run_pipe(pipe, prompts):
+def run_pipe(pipe):
     if torch.backends.mps.is_available():
         device = "mps"
     elif torch.cuda.is_available():
@@ -235,18 +255,29 @@ def run_pipe(pipe, prompts):
 
     global rng, rng_state
     rng.set_state(rng_state)
-    return np.array([latent.cpu().numpy() for latent in pipe.to(device)(
-        prompt=prompts,
+    kwargs = dict(
+        prompt=RANDOM_TEST_DATA,
+        negative_prompt=[""] * len(RANDOM_TEST_DATA),
         num_inference_steps=1,
         height=TEST_RESOLUTION,
         width=TEST_RESOLUTION,
         output_type="latent",
-        generator=rng).images])
+        generator=rng
+    )
+    if "Img2Img" in pipe.__class__.__name__:
+        kwargs["image"] = RANDOM_TEST_IMAGE_DATA
+        kwargs.pop("height")
+        kwargs.pop("width")
+
+        # Run a single denoising step
+        kwargs["num_inference_steps"] = 4
+        kwargs["strength"] = 0.25
+
+    return np.array([latent.cpu().numpy() for latent in pipe.to(device)(**kwargs).images])
 
 
 def benchmark_signal_integrity(pipe,
                                candidates,
-                               sample_input_batch,
                                nbits,
                                cumulative,
                                in_ngroups=1,
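Note: the Img2Img branch above sets num_inference_steps=4 with strength=0.25 so that only one denoising step actually runs, matching the single step used for the text-to-image case. A quick sketch of that arithmetic, assuming the usual diffusers img2img timestep scaling (stated here as an assumption, not copied from the commit):

def effective_img2img_steps(num_inference_steps: int, strength: float) -> int:
    # diffusers-style img2img schedules num_inference_steps timesteps but
    # only denoises the last `strength` fraction of them.
    return min(int(num_inference_steps * strength), num_inference_steps)

assert effective_img2img_steps(4, 0.25) == 1  # one refiner denoising step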
@@ -263,7 +294,7 @@ def benchmark_signal_integrity(pipe,
 
     # If reference outputs are not provided, treat current pipe as reference
     if ref_out is None:
-        ref_out = run_pipe(pipe, sample_input_batch)
+        ref_out = run_pipe(pipe)
 
     for candidate in tqdm(candidates):
         palettized = False
@@ -280,7 +311,7 @@
         if not palettized:
             raise KeyError(name)
 
-        test_out = run_pipe(pipe, sample_input_batch)
+        test_out = run_pipe(pipe)
 
         if not cumulative:
             restore_weight(module, orig_weight)
@@ -304,11 +335,11 @@ def descending_psnr_order(results):
 def simulate_quant_fn(ref_pipe, quantization_to_simulate):
     simulated_pipe = deepcopy(ref_pipe.to('cpu'))
     quantization_to_simulate(simulated_pipe.unet)
-    simulated_out = run_pipe(simulated_pipe, RANDOM_TEST_DATA)
+    simulated_out = run_pipe(simulated_pipe)
     del simulated_pipe
     gc.collect()
 
-    ref_out = run_pipe(ref_pipe, RANDOM_TEST_DATA)
+    ref_out = run_pipe(ref_pipe)
     simulated_psnr = sum([
         float(f"{compute_psnr(r,t):.1f}")
         for r,t in zip(ref_out, simulated_out)
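Note: simulate_quant_fn above scores a simulated quantization by averaging compute_psnr over the reference and simulated latents. For readers unfamiliar with the metric, here is a generic PSNR-in-dB sketch; it is illustrative only and may differ from the repo's compute_psnr in normalization details (e.g. peak vs. mean signal power):

import numpy as np

def psnr_db(reference, test):
    # Generic peak-signal-to-noise-ratio style score in dB between two arrays:
    # higher means the quantized output tracks the reference more closely.
    reference = np.asarray(reference, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    noise_power = np.mean((reference - test) ** 2) + 1e-12
    signal_power = np.mean(reference ** 2)
    return 10.0 * np.log10(signal_power / noise_power)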
@@ -414,7 +445,7 @@ def main(args):
     logger.info("Done.")
 
     # Cache reference outputs
-    ref_out = run_pipe(pipe, RANDOM_TEST_DATA)
+    ref_out = run_pipe(pipe)
 
     # Bookkeeping
     os.makedirs(args.o, exist_ok=True)
@@ -442,7 +473,6 @@
         results['single_layer'][str(nbits)] = benchmark_signal_integrity(
             pipe,
             candidates,
-            RANDOM_TEST_DATA,
             nbits,
             cumulative=False,
             ref_out=ref_out,
@@ -457,7 +487,6 @@
         results['cumulative'][str(nbits)] = benchmark_signal_integrity(
             deepcopy(pipe),
             sorted_candidates,
-            RANDOM_TEST_DATA,
             nbits,
             cumulative=True,
             ref_out=ref_out,

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 9 additions & 1 deletion
@@ -14,6 +14,7 @@
 from diffusers import (
     StableDiffusionPipeline,
     StableDiffusionXLPipeline,
+    StableDiffusionXLImg2ImgPipeline,
     ControlNetModel
 )
 import gc
@@ -1219,7 +1220,13 @@ def convert_controlnet(pipe, args):
     gc.collect()
 
 def get_pipeline(args):
-    if 'xl' in args.model_version:
+    if all(key in args.model_version for key in ['refiner', 'xl']):
+        pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(args.model_version,
+                                                                torch_dtype=torch.float16,
+                                                                variant="fp16",
+                                                                use_safetensors=True,
+                                                                use_auth_token=True)
+    elif 'xl' in args.model_version:
         pipe = StableDiffusionXLPipeline.from_pretrained(args.model_version,
                                                          torch_dtype=torch.float16,
                                                          variant="fp16",
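Note: with this hunk, get_pipeline routes any model version containing both 'refiner' and 'xl' to StableDiffusionXLImg2ImgPipeline rather than StableDiffusionXLPipeline. A hedged usage sketch, assuming the standard Hugging Face refiner model id and an argparse-style args object (both are assumptions, not part of the commit):

from argparse import Namespace

# Hypothetical args object; the real one comes from the script's argument parser.
args = Namespace(model_version="stabilityai/stable-diffusion-xl-refiner-1.0")

pipe = get_pipeline(args)
# The model version contains both 'refiner' and 'xl', so this loads
# StableDiffusionXLImg2ImgPipeline. Its class name contains "Img2Img", which is
# what run_pipe() in the pre-analysis script checks to take the image-conditioned branch.
print(type(pipe).__name__)  # StableDiffusionXLImg2ImgPipeline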
@@ -1230,6 +1237,7 @@ def get_pipeline(args):
                                                      use_auth_token=True)
     return pipe
 
+
 def main(args):
     os.makedirs(args.o, exist_ok=True)
 