Skip to content

Commit d229861

Browse files
committed
Support pipeline model creation for chunks
1 parent e3c1e36 commit d229861

File tree

1 file changed

+61
-27
lines changed

1 file changed

+61
-27
lines changed

python_coreml_stable_diffusion/chunk_mlprogram.py

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,39 +31,51 @@
3131
import time
3232

3333

34-
def _verify_output_correctness_of_chunks(full_model, first_chunk_model,
35-
second_chunk_model):
34+
def _verify_output_correctness_of_chunks(full_model,
35+
first_chunk_model=None,
36+
second_chunk_model=None,
37+
pipeline_model=None,):
3638
""" Verifies the end-to-end output correctness of full (original) model versus chunked models
3739
"""
3840
# Generate inputs for first chunk and full model
3941
input_dict = {}
4042
for input_desc in full_model._spec.description.input:
4143
input_dict[input_desc.name] = random_gen_input_feature_type(input_desc)
4244

43-
# Generate outputs for first chunk and full model
45+
# Generate outputs for full model
4446
outputs_from_full_model = full_model.predict(input_dict)
45-
outputs_from_first_chunk_model = first_chunk_model.predict(input_dict)
46-
47-
# Prepare inputs for second chunk model from first chunk's outputs and regular inputs
48-
second_chunk_input_dict = {}
49-
for input_desc in second_chunk_model._spec.description.input:
50-
if input_desc.name in outputs_from_first_chunk_model:
51-
second_chunk_input_dict[
52-
input_desc.name] = outputs_from_first_chunk_model[
47+
48+
if pipeline_model is not None:
49+
outputs_from_pipeline_model = pipeline_model.predict(input_dict)
50+
final_outputs = outputs_from_pipeline_model
51+
52+
elif first_chunk_model is not None and second_chunk_model is not None:
53+
# Generate outputs for first chunk
54+
outputs_from_first_chunk_model = first_chunk_model.predict(input_dict)
55+
56+
# Prepare inputs for second chunk model from first chunk's outputs and regular inputs
57+
second_chunk_input_dict = {}
58+
for input_desc in second_chunk_model._spec.description.input:
59+
if input_desc.name in outputs_from_first_chunk_model:
60+
second_chunk_input_dict[
61+
input_desc.name] = outputs_from_first_chunk_model[
62+
input_desc.name]
63+
else:
64+
second_chunk_input_dict[input_desc.name] = input_dict[
5365
input_desc.name]
54-
else:
55-
second_chunk_input_dict[input_desc.name] = input_dict[
56-
input_desc.name]
5766

58-
# Generate output for second chunk model
59-
outputs_from_second_chunk_model = second_chunk_model.predict(
60-
second_chunk_input_dict)
67+
# Generate output for second chunk model
68+
outputs_from_second_chunk_model = second_chunk_model.predict(
69+
second_chunk_input_dict)
70+
final_outputs = outputs_from_second_chunk_model
71+
else:
72+
raise ValueError
6173

6274
# Verify correctness across all outputs from second chunk and full model
6375
for out_name in outputs_from_full_model.keys():
6476
torch2coreml.report_correctness(
6577
original_outputs=outputs_from_full_model[out_name],
66-
final_outputs=outputs_from_second_chunk_model[out_name],
78+
final_outputs=final_outputs[out_name],
6779
log_prefix=f"{out_name}")
6880

6981

@@ -302,16 +314,32 @@ def main(args):
302314
shutil.rmtree(args.mlpackage_path)
303315
logger.info("Done.")
304316

305-
# Save the chunked models to disk
306-
out_path_chunk1 = os.path.join(args.o, name + "_chunk1.mlpackage")
307-
out_path_chunk2 = os.path.join(args.o, name + "_chunk2.mlpackage")
317+
if args.merge_chunks_in_pipeline_model:
318+
# Make a single pipeline model to manage the model chunks
319+
pipeline_model = ct.utils.make_pipeline(model_chunk1, model_chunk2)
320+
out_path_pipeline = os.path.join(args.o, name + "_chunked_pipeline.mlpackage")
308321

309-
logger.info(
310-
f"Saved chunks in {args.o} with the suffix _chunk1.mlpackage and _chunk2.mlpackage"
311-
)
312-
model_chunk1.save(out_path_chunk1)
313-
model_chunk2.save(out_path_chunk2)
314-
logger.info("Done.")
322+
# Save and reload to ensure CPU placement
323+
pipeline_model.save(out_path_pipeline)
324+
pipeline_model = ct.models.MLModel(out_path_pipeline)
325+
326+
if args.check_output_correctness:
327+
logger.info("Verifying output correctness of pipeline model")
328+
_verify_output_correctness_of_chunks(
329+
full_model=model,
330+
pipeline_model=pipeline_model,
331+
)
332+
else:
333+
# Save the chunked models to disk
334+
out_path_chunk1 = os.path.join(args.o, name + "_chunk1.mlpackage")
335+
out_path_chunk2 = os.path.join(args.o, name + "_chunk2.mlpackage")
336+
337+
logger.info(
338+
f"Saved chunks in {args.o} with the suffix _chunk1.mlpackage and _chunk2.mlpackage"
339+
)
340+
model_chunk1.save(out_path_chunk1)
341+
model_chunk2.save(out_path_chunk2)
342+
logger.info("Done.")
315343

316344

317345
if __name__ == "__main__":
@@ -341,6 +369,12 @@ def main(args):
341369
("If specified, compares the outputs of original Core ML model with that of pipelined CoreML model chunks and reports PSNR in dB. ",
342370
"Enabling this feature uses more memory. Disable it if your machine runs out of memory."
343371
))
372+
parser.add_argument(
373+
"--merge-chunks-in-pipeline-model",
374+
action="store_true",
375+
help=
376+
("If specified, model chunks are managed inside a single pipeline model for easier asset maintenance"
377+
))
344378

345379
args = parser.parse_args()
346380
main(args)

0 commit comments

Comments (0)