Skip to content

Commit f7b2618

Browse files
committed
refactor the model chunking script to work with all version of coremltools
1 parent cf16df8 commit f7b2618

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

python_coreml_stable_diffusion/chunk_mlprogram.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def _get_op_idx_split_location(prog: Program):
101101
""" Find the op that approximately bisects the graph as measure by weights size on each side
102102
"""
103103
main_block = prog.functions["main"]
104+
main_block.operations = list(main_block.operations)
104105
total_size_in_mb = 0
105106

106107
for op in main_block.operations:
@@ -132,6 +133,7 @@ def _get_first_chunk_outputs(block, op_idx):
132133
# to the second program (all ops from op_idx+1 till the end). These all vars need to be made the output
133134
# of the first program and the input of the second program
134135
boundary_vars = set()
136+
block.operations = list(block.operations)
135137
for i in range(op_idx + 1):
136138
op = block.operations[i]
137139
if not op.op_type.startswith("const"):
@@ -181,6 +183,7 @@ def _make_second_chunk_prog(prog, op_idx):
181183
boundary_vars = _get_first_chunk_outputs(block, op_idx)
182184

183185
# This op will not be included in this program. Its output var will be made into an input
186+
block.operations = list(block.operations)
184187
boundary_op = block.operations[op_idx]
185188

186189
# Add all boundary ops as inputs
@@ -228,7 +231,8 @@ def _make_second_chunk_prog(prog, op_idx):
228231
return prog
229232

230233

231-
def main(args):
234+
def _legancy_model_chunking(args):
235+
# TODO: Remove this method after setting the coremltools dependency >= 8.0
232236
os.makedirs(args.o, exist_ok=True)
233237

234238
# Check filename extension
@@ -307,13 +311,6 @@ def main(args):
307311
second_chunk_model=model_chunk2,
308312
)
309313

310-
# Remove original (non-chunked) model if requested
311-
if args.remove_original:
312-
logger.info(
313-
"Removing original (non-chunked) model at {args.mlpackage_path}")
314-
shutil.rmtree(args.mlpackage_path)
315-
logger.info("Done.")
316-
317314
if args.merge_chunks_in_pipeline_model:
318315
# Make a single pipeline model to manage the model chunks
319316
pipeline_model = ct.utils.make_pipeline(model_chunk1, model_chunk2)
@@ -342,6 +339,37 @@ def main(args):
342339
logger.info("Done.")
343340

344341

342+
def main(args):
343+
ct_version = ct.__version__
344+
345+
if ct_version != "8.0b2" and ct_version < "8.0":
346+
# With coremltools version <= 8.0b1,
347+
# we use the lagancy implementation.
348+
# TODO: Remove the logic after setting the coremltools dependency >= 8.0.
349+
logger.info(
350+
f"coremltools version {ct_version} detected. Recommended upgrading the package version to "
351+
f"'8.0b2' when you running chunk_mlprogram.py script for the latest supports and bug fixes."
352+
)
353+
_legancy_model_chunking(args)
354+
else:
355+
# Starting from coremltools==8.0b2, there is this `bisect_model` API that
356+
# we can directly call into.
357+
from coremltools.models.utils import bisect_model
358+
ct.models.utils.bisect_model(
359+
model=args.mlpackage_path,
360+
output_dir=args.o,
361+
merge_chunks_to_pipeline=args.merge_chunks_in_pipeline_model,
362+
check_output_correctness=args.check_output_correctness,
363+
)
364+
365+
# Remove original (non-chunked) model if requested
366+
if args.remove_original:
367+
logger.info(
368+
"Removing original (non-chunked) model at {args.mlpackage_path}")
369+
shutil.rmtree(args.mlpackage_path)
370+
logger.info("Done.")
371+
372+
345373
if __name__ == "__main__":
346374
parser = argparse.ArgumentParser()
347375
parser.add_argument(

0 commit comments

Comments
 (0)