@@ -101,6 +101,7 @@ def _get_op_idx_split_location(prog: Program):
101
101
""" Find the op that approximately bisects the graph as measure by weights size on each side
102
102
"""
103
103
main_block = prog .functions ["main" ]
104
+ main_block .operations = list (main_block .operations )
104
105
total_size_in_mb = 0
105
106
106
107
for op in main_block .operations :
@@ -132,6 +133,7 @@ def _get_first_chunk_outputs(block, op_idx):
132
133
# to the second program (all ops from op_idx+1 till the end). All of these vars need to be made the output
133
134
# of the first program and the input of the second program
134
135
boundary_vars = set ()
136
+ block .operations = list (block .operations )
135
137
for i in range (op_idx + 1 ):
136
138
op = block .operations [i ]
137
139
if not op .op_type .startswith ("const" ):
@@ -181,6 +183,7 @@ def _make_second_chunk_prog(prog, op_idx):
181
183
boundary_vars = _get_first_chunk_outputs (block , op_idx )
182
184
183
185
# This op will not be included in this program. Its output var will be made into an input
186
+ block .operations = list (block .operations )
184
187
boundary_op = block .operations [op_idx ]
185
188
186
189
# Add all boundary ops as inputs
@@ -228,7 +231,8 @@ def _make_second_chunk_prog(prog, op_idx):
228
231
return prog
229
232
230
233
231
- def main (args ):
234
+ def _legacy_model_chunking (args ):
235
+ # TODO: Remove this method after setting the coremltools dependency >= 8.0
232
236
os .makedirs (args .o , exist_ok = True )
233
237
234
238
# Check filename extension
@@ -307,13 +311,6 @@ def main(args):
307
311
second_chunk_model = model_chunk2 ,
308
312
)
309
313
310
- # Remove original (non-chunked) model if requested
311
- if args .remove_original :
312
- logger .info (
313
- "Removing original (non-chunked) model at {args.mlpackage_path}" )
314
- shutil .rmtree (args .mlpackage_path )
315
- logger .info ("Done." )
316
-
317
314
if args .merge_chunks_in_pipeline_model :
318
315
# Make a single pipeline model to manage the model chunks
319
316
pipeline_model = ct .utils .make_pipeline (model_chunk1 , model_chunk2 )
@@ -342,6 +339,39 @@ def main(args):
342
339
logger .info ("Done." )
343
340
344
341
342
def main(args):
    """Chunk the model at ``args.mlpackage_path`` into two pieces.

    When the installed coremltools provides ``coremltools.models.utils.bisect_model``
    (available starting from coremltools==8.0b2), that API is called directly.
    Otherwise the legacy implementation in this script is used.

    The original (non-chunked) model is removed afterwards if
    ``args.remove_original`` is set, regardless of which path was taken.

    Args:
        args: Parsed command-line arguments. This function reads
            ``mlpackage_path``, ``o``, ``merge_chunks_in_pipeline_model``,
            ``check_output_correctness`` and ``remove_original``.
    """
    ct_version = ct.__version__

    # Detect the feature instead of comparing version strings: a lexicographic
    # comparison misorders pre-releases ("8.0b1" < "8.0" is False) and
    # multi-digit major versions ("10.0" < "8.0" is True).
    try:
        from coremltools.models.utils import bisect_model
    except ImportError:
        bisect_model = None

    if bisect_model is None:
        # With coremltools version <= 8.0b1, we use the legacy implementation.
        # TODO: Remove the logic after setting the coremltools dependency >= 8.0.
        logger.info(
            f"coremltools version {ct_version} detected. Recommended upgrading the package version to "
            f"'8.0b2' when you running chunk_mlprogram.py script for the latest supports and bug fixes."
        )
        _legacy_model_chunking(args)
    else:
        # Starting from coremltools==8.0b2, there is this `bisect_model` API that
        # we can directly call into.
        logger.info(f"Start chunking model {args.mlpackage_path} into two pieces.")
        bisect_model(
            model=args.mlpackage_path,
            output_dir=args.o,
            merge_chunks_to_pipeline=args.merge_chunks_in_pipeline_model,
            check_output_correctness=args.check_output_correctness,
        )
        logger.info("Model chunking is done.")

    # Remove original (non-chunked) model if requested
    if args.remove_original:
        # The f-prefix is required so the actual path is interpolated into the log.
        logger.info(
            f"Removing original (non-chunked) model at {args.mlpackage_path}")
        shutil.rmtree(args.mlpackage_path)
        logger.info("Done.")
373
+
374
+
345
375
if __name__ == "__main__" :
346
376
parser = argparse .ArgumentParser ()
347
377
parser .add_argument (
0 commit comments