Skip to content

Commit 70cfcce

Browse files
authored
Merge pull request #354 from jakesabathia2/eng/PR-henry-refactor-chunk-script
Refactor the model chunking script to work with all version of coremltools
2 parents cf16df8 + c6884bf commit 70cfcce

File tree

1 file changed

+38
-8
lines changed

1 file changed

+38
-8
lines changed

python_coreml_stable_diffusion/chunk_mlprogram.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def _get_op_idx_split_location(prog: Program):
101101
""" Find the op that approximately bisects the graph as measure by weights size on each side
102102
"""
103103
main_block = prog.functions["main"]
104+
main_block.operations = list(main_block.operations)
104105
total_size_in_mb = 0
105106

106107
for op in main_block.operations:
@@ -132,6 +133,7 @@ def _get_first_chunk_outputs(block, op_idx):
132133
# to the second program (all ops from op_idx+1 till the end). These all vars need to be made the output
133134
# of the first program and the input of the second program
134135
boundary_vars = set()
136+
block.operations = list(block.operations)
135137
for i in range(op_idx + 1):
136138
op = block.operations[i]
137139
if not op.op_type.startswith("const"):
@@ -181,6 +183,7 @@ def _make_second_chunk_prog(prog, op_idx):
181183
boundary_vars = _get_first_chunk_outputs(block, op_idx)
182184

183185
# This op will not be included in this program. Its output var will be made into an input
186+
block.operations = list(block.operations)
184187
boundary_op = block.operations[op_idx]
185188

186189
# Add all boundary ops as inputs
@@ -228,7 +231,8 @@ def _make_second_chunk_prog(prog, op_idx):
228231
return prog
229232

230233

231-
def main(args):
234+
def _legacy_model_chunking(args):
235+
# TODO: Remove this method after setting the coremltools dependency >= 8.0
232236
os.makedirs(args.o, exist_ok=True)
233237

234238
# Check filename extension
@@ -307,13 +311,6 @@ def main(args):
307311
second_chunk_model=model_chunk2,
308312
)
309313

310-
# Remove original (non-chunked) model if requested
311-
if args.remove_original:
312-
logger.info(
313-
"Removing original (non-chunked) model at {args.mlpackage_path}")
314-
shutil.rmtree(args.mlpackage_path)
315-
logger.info("Done.")
316-
317314
if args.merge_chunks_in_pipeline_model:
318315
# Make a single pipeline model to manage the model chunks
319316
pipeline_model = ct.utils.make_pipeline(model_chunk1, model_chunk2)
@@ -342,6 +339,39 @@ def main(args):
342339
logger.info("Done.")
343340

344341

342+
def main(args):
343+
ct_version = ct.__version__
344+
345+
if ct_version != "8.0b2" and ct_version < "8.0":
346+
# With coremltools version <= 8.0b1,
347+
# we use the legacy implementation.
348+
# TODO: Remove the logic after setting the coremltools dependency >= 8.0.
349+
logger.info(
350+
f"coremltools version {ct_version} detected. Recommended upgrading the package version to "
351+
f"'8.0b2' when you running chunk_mlprogram.py script for the latest supports and bug fixes."
352+
)
353+
_legacy_model_chunking(args)
354+
else:
355+
# Starting from coremltools==8.0b2, there is this `bisect_model` API that
356+
# we can directly call into.
357+
from coremltools.models.utils import bisect_model
358+
logger.info(f"Start chunking model {args.mlpackage_path} into two pieces.")
359+
ct.models.utils.bisect_model(
360+
model=args.mlpackage_path,
361+
output_dir=args.o,
362+
merge_chunks_to_pipeline=args.merge_chunks_in_pipeline_model,
363+
check_output_correctness=args.check_output_correctness,
364+
)
365+
logger.info(f"Model chunking is done.")
366+
367+
# Remove original (non-chunked) model if requested
368+
if args.remove_original:
369+
logger.info(
370+
"Removing original (non-chunked) model at {args.mlpackage_path}")
371+
shutil.rmtree(args.mlpackage_path)
372+
logger.info("Done.")
373+
374+
345375
if __name__ == "__main__":
346376
parser = argparse.ArgumentParser()
347377
parser.add_argument(

0 commit comments

Comments
 (0)