Skip to content

Commit 3e04bce

Browse files
committed
On macOS 15+, set fast prediction optimization hint for the unet.
1 parent 32be4af commit 3e04bce

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

python_coreml_stable_diffusion/coreml_model.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,28 @@
1616

1717
import os
1818
import time
19+
import subprocess
20+
import sys
21+
22+
23+
def _macos_version():
24+
"""
25+
Returns macOS version as a tuple of integers. On non-Macs, returns an empty tuple.
26+
"""
27+
if sys.platform == "darwin":
28+
try:
29+
ver_str = subprocess.run(["sw_vers", "-productVersion"], stdout=subprocess.PIPE).stdout.decode('utf-8').strip('\n')
30+
return tuple([int(v) for v in ver_str.split(".")])
31+
except:
32+
raise Exception("Unable to determine the macOS version")
33+
return ()
1934

2035

2136
class CoreMLModel:
2237
""" Wrapper for running CoreML models using coremltools
2338
"""
2439

25-
def __init__(self, model_path, compute_unit, sources='packages'):
40+
def __init__(self, model_path, compute_unit, sources='packages', optimization_hints=None):
2641

2742
logger.info(f"Loading {model_path}")
2843

@@ -31,7 +46,10 @@ def __init__(self, model_path, compute_unit, sources='packages'):
3146
assert os.path.exists(model_path) and model_path.endswith(".mlpackage")
3247

3348
self.model = ct.models.MLModel(
34-
model_path, compute_units=ct.ComputeUnit[compute_unit])
49+
model_path,
50+
compute_units=ct.ComputeUnit[compute_unit],
51+
optimization_hints=optimization_hints,
52+
)
3553
DTYPE_MAP = {
3654
65552: np.float16,
3755
65568: np.float32,
@@ -47,7 +65,11 @@ def __init__(self, model_path, compute_unit, sources='packages'):
4765
elif sources == 'compiled':
4866
assert os.path.exists(model_path) and model_path.endswith(".mlmodelc")
4967

50-
self.model = ct.models.CompiledMLModel(model_path, ct.ComputeUnit[compute_unit])
68+
self.model = ct.models.CompiledMLModel(
69+
model_path,
70+
compute_units=ct.ComputeUnit[compute_unit],
71+
optimization_hints=optimization_hints,
72+
)
5173

5274
# Grab expected inputs from metadata.json
5375
with open(os.path.join(model_path, 'metadata.json'), 'r') as f:
@@ -170,7 +192,15 @@ def _load_mlpackage(submodule_name,
170192
raise FileNotFoundError(
171193
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
172194

173-
return CoreMLModel(mlpackage_path, compute_unit, sources=sources)
195+
# On macOS 15+, set fast prediction optimization hint for the unet.
196+
optimization_hints = None
197+
if submodule_name == "unet" and _macos_version() >= (15, 0):
198+
optimization_hints = {"specializationStrategy": ct.SpecializationStrategy.FastPrediction}
199+
200+
return CoreMLModel(mlpackage_path,
201+
compute_unit,
202+
sources=sources,
203+
optimization_hints=optimization_hints)
174204

175205

176206
def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):

0 commit comments

Comments
 (0)