Skip to content

Commit d456a97

Browse files
authored
Merge pull request #311 from TobyRoseman/mbp-new-api
Applying mixed bit compression using new optimize API
2 parents 7449ce4 + ec09fe0 commit d456a97

File tree

4 files changed

+30
-55
lines changed

4 files changed

+30
-55
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
*~
2+
13
# Swift Package
24
.DS_Store
35
/.build

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ Resources:
183183
<details>
184184
<summary> Details (Click to expand) </summary>
185185

186-
This section describes an advanced compression algorithm called [Mixed-Bit Palettization (MBP)](https://huggingface.co/blog/stable-diffusion-xl-coreml#what-is-mixed-bit-palettization) built on top of the [Post-Training Weight Palettization tools from coremltools-7.0](https://apple.github.io/coremltools/docs-guides/source/post-training-palettization.html).
186+
This section describes an advanced compression algorithm called [Mixed-Bit Palettization (MBP)](https://huggingface.co/blog/stable-diffusion-xl-coreml#what-is-mixed-bit-palettization) built on top of the [Post-Training Weight Palettization tools](https://apple.github.io/coremltools/docs-guides/source/post-training-palettization.html) and using the [Weights Metadata API](https://apple.github.io/coremltools/docs-guides/source/mlmodel-utilities.html#get-weights-metadata) from [coremltools](https://github.com/apple/coremltools).
187187

188188
MBP builds a per-layer "palettization recipe" by picking a suitable number of bits among the Neural Engine supported bit-widths of 1, 2, 4, 6 and 8 in order to achieve the minimum average bit-width while maintaining a desired level of signal strength. The signal strength is measured by comparing the compressed model's output to that of the original float16 model. Given the same random seed and text prompts, PSNR between denoised latents is computed. The compression rate will depend on the model version as well as the tolerance for signal loss (drop in PSNR) since this algorithm is adaptive.
189189

python_coreml_stable_diffusion/mixed_bit_compression_apply.py

Lines changed: 23 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1-
from pprint import pprint
21
import argparse
3-
import coremltools as ct
42
import gc
53
import json
64
import logging
7-
import numpy as np
85
import os
96

7+
import coremltools as ct
8+
import coremltools.optimize.coreml as cto
9+
import numpy as np
10+
1011
from python_coreml_stable_diffusion.torch2coreml import get_pipeline
1112
from python_coreml_stable_diffusion.mixed_bit_compression_pre_analysis import (
1213
NBITS,
1314
PALETTIZE_MIN_SIZE as MIN_SIZE
1415
)
1516

17+
1618
logging.basicConfig()
1719
logger = logging.getLogger(__name__)
1820
logger.setLevel(logging.INFO)
@@ -23,9 +25,6 @@ def main(args):
2325
coreml_model = ct.models.MLModel(args.mlpackage_path, compute_units=ct.ComputeUnit.CPU_ONLY)
2426
logger.info(f"Loaded {args.mlpackage_path}")
2527

26-
# Keep track of precision stats
27-
precision_stats = {nbits:{'num_tensors': 0, 'numel': 0} for nbits in NBITS}
28-
2928
# Load palettization recipe
3029
with open(args.pre_analysis_json_path, 'r') as f:
3130
pre_analysis = json.load(f)
@@ -62,53 +61,29 @@ def get_tensor_hash(tensor):
6261
del pipe
6362
gc.collect()
6463

65-
current_nbits: int
66-
67-
def op_selector(const):
68-
parameter_tensor = const.val.val
69-
if parameter_tensor.size < MIN_SIZE:
70-
return False
71-
72-
if parameter_tensor.dtype != np.float16:
73-
# These are the tensors that were compressed to look-up indices in previous passes
74-
return False
75-
76-
tensor_hash = get_tensor_hash(parameter_tensor)
77-
tensor_spec = f"{tensor_hash} with shape {parameter_tensor.shape}"
78-
79-
80-
hashes = list(hashed_recipe)
81-
pdist = np.abs(np.array(hashes) - tensor_hash)
64+
op_name_configs = {}
65+
weight_metadata = cto.get_weights_metadata(coreml_model, weight_threshold=MIN_SIZE)
66+
hashes = np.array(list(hashed_recipe))
67+
for name, metadata in weight_metadata.items():
68+
# Look up target bits for this weight
69+
tensor_hash = get_tensor_hash(metadata.val)
70+
pdist = np.abs(hashes - tensor_hash)
71+
assert(pdist.min() < 0.01)
8272
matched = pdist.argmin()
83-
logger.debug(f"{tensor_spec}: {tensor_hash} matched with {hashes[matched]} (hash error={pdist.min()})")
84-
8573
target_nbits = hashed_recipe[hashes[matched]]
86-
87-
do_palettize = current_nbits == target_nbits
88-
if do_palettize:
89-
logger.debug(f"{tensor_spec}: Palettizing to {target_nbits}-bit palette")
90-
precision_stats[current_nbits]['num_tensors'] += 1
91-
precision_stats[current_nbits]['numel'] += np.prod(parameter_tensor.shape)
92-
return True
93-
return False
94-
95-
for nbits in NBITS:
96-
logger.info(f"Processing tensors targeting {nbits}-bit palettes")
97-
current_nbits = nbits
98-
99-
config = ct.optimize.coreml.OptimizationConfig(
100-
global_config=ct.optimize.coreml.OpPalettizerConfig(mode="kmeans", nbits=nbits, weight_threshold=None,),
101-
is_deprecated=True,
102-
op_selector=op_selector,
74+
75+
if target_nbits == 16:
76+
continue
77+
78+
op_name_configs[name] = cto.OpPalettizerConfig(
79+
mode="kmeans",
80+
nbits=target_nbits,
81+
weight_threshold=int(MIN_SIZE)
10382
)
104-
coreml_model = ct.optimize.coreml.palettize_weights(coreml_model, config=config)
105-
logger.info(f"{precision_stats[nbits]['num_tensors']} tensors are palettized with {nbits} bits")
10683

84+
config = ct.optimize.coreml.OptimizationConfig(op_name_configs=op_name_configs)
85+
coreml_model = ct.optimize.coreml.palettize_weights(coreml_model, config)
10786

108-
tot_numel = sum([precision_stats[nbits]['numel'] for nbits in NBITS])
109-
final_size = sum([precision_stats[nbits]['numel'] * nbits for nbits in NBITS])
110-
logger.info(f"Palettization result: {final_size / tot_numel:.2f}-bits resulting in {final_size / (8*1e6)} MB")
111-
pprint(precision_stats)
11287
coreml_model.save(args.o)
11388

11489

python_coreml_stable_diffusion/mixed_bit_compression_pre_analysis.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import requests
2222
torch.set_grad_enabled(False)
2323

24-
from tqdm import tqdm, trange
24+
from tqdm import tqdm
2525

2626
# Bit-widths the Neural Engine is capable of accelerating
2727
NBITS = [1, 2, 4, 6, 8]
@@ -342,8 +342,8 @@ def simulate_quant_fn(ref_pipe, quantization_to_simulate):
342342

343343
ref_out = run_pipe(ref_pipe)
344344
simulated_psnr = sum([
345-
float(f"{compute_psnr(r,t):.1f}")
346-
for r,t in zip(ref_out, simulated_out)
345+
float(f"{compute_psnr(r, t):.1f}")
346+
for r, t in zip(ref_out, simulated_out)
347347
]) / len(ref_out)
348348

349349
return simulated_out, simulated_psnr
@@ -459,9 +459,7 @@ def main(args):
459459
json_name = f"{args.model_version.replace('/','-')}_palettization_recipe.json"
460460
candidates, sizes = get_palettizable_modules(pipe.unet)
461461

462-
sizes_table = {
463-
candidate:size for candidate, size in zip(candidates, sizes)
464-
}
462+
sizes_table = dict(zip(candidates, sizes))
465463

466464
if os.path.isfile(os.path.join(args.o, json_name)):
467465
with open(os.path.join(args.o, json_name), "r") as f:

0 commit comments

Comments
 (0)