Skip to content

Commit 35613e5

Browse files
authored
Update chunking script to work with models that contain compressed weights, i.e. constexpr_* ops (#245)
1 parent 40ff5f5 commit 35613e5

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

python_coreml_stable_diffusion/chunk_mlprogram.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,10 @@ def _get_op_idx_split_location(prog: Program):
105105
size_in_mb = op.val.val.size * op.val.val.itemsize / (1024 * 1024)
106106
cumulative_size_in_mb += size_in_mb
107107

108-
if (cumulative_size_in_mb > half_size and op.op_type != "const"
108+
# Note: The condition `not op.op_type.startswith("const")` ensures that the
109+
# incision op is neither a "const" op nor one of the "constexpr_*" ops that
110+
# are used to store compressed weights
111+
if (cumulative_size_in_mb > half_size and not op.op_type.startswith("const")
109112
and len(op.outputs) == 1
110113
and len(op.outputs[0].child_ops) == 1):
111114
op_idx = main_block.operations.index(op)
@@ -192,6 +195,11 @@ def _make_second_chunk_prog(prog, op_idx):
192195
anchor_op=boundary_op,
193196
old_var=var,
194197
new_var=new_var,
198+
# This is needed if the program contains "constexpr_*" ops. In normal cases, there are stricter
199+
# rules for removing them, and their presence may prevent replacing this var.
200+
# However in this case, since we want to remove all the ops in chunk 1, we can safely
201+
# set this to True.
202+
force_replace=True,
195203
)
196204

197205
PASS_REGISTRY["common::dead_code_elimination"](prog)

0 commit comments

Comments
 (0)