Skip to content

Commit 223ff38

Browse files
minettekaum and Nils Fleischmann
authored
feat: add pruner algorithm (#470)
* added padded_pruning algorithm and flux_tiny_random_with_tokenizer to fixtures * fixing typo * Add co-author Co-authored-by: Nils Fleischmann <nils.fleischmann@outlook.com> --------- Co-authored-by: Nils Fleischmann <nils.fleischmann@outlook.com>
1 parent d4b259c commit 223ff38

File tree

3 files changed

+257
-0
lines changed

3 files changed

+257
-0
lines changed
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import functools
18+
import inspect
19+
from collections.abc import Iterable
20+
from typing import Any
21+
22+
from ConfigSpace import OrdinalHyperparameter
23+
24+
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
25+
from pruna.algorithms.base.tags import AlgorithmTag
26+
from pruna.config.smash_config import SmashConfigPrefixWrapper
27+
from pruna.engine.model_checks import is_diffusers_model
28+
from pruna.engine.save import SAVE_FUNCTIONS
29+
30+
31+
class PaddingPruner(PrunaAlgorithmBase):
    """
    Implement Padding Pruning for Diffusers pipelines.

    Padding Pruning removes unused padding tokens from the prompt embedding of diffusers pipelines.
    """

    algorithm_name: str = "padding_pruning"
    group_tags: list[AlgorithmTag] = [AlgorithmTag.PRUNER]
    references: dict[str, str] = {}
    tokenizer_required: bool = True
    processor_required: bool = False
    runs_on: list[str] = ["cpu", "cuda", "accelerate"]
    dataset_required: bool = False
    save_fn = SAVE_FUNCTIONS.reapply
    compatible_before: Iterable[str | AlgorithmTag] = ["qkv_diffusers"]
    compatible_after: Iterable[str | AlgorithmTag] = [
        AlgorithmTag.CACHER,
        "hyper",
        "torch_compile",
        "stable_fast",
        "hqq_diffusers",
        "diffusers_int8",
        "torchao",
        "flash_attn3",
        "ring_attn",
    ]

    def get_hyperparameters(self) -> list:
        """
        Get the hyperparameters for the Padding Pruner.

        Returns
        -------
        list
            A list of hyperparameters.
        """
        return [
            OrdinalHyperparameter(
                "min_sequence_length",
                sequence=[32, 64, 128, 256],
                default_value=64,
                meta=dict(desc="Minimum sequence length used to embed a prompt."),
            ),
        ]

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is a diffusers pipeline with a max_sequence_length parameter.

        Parameters
        ----------
        model : Any
            The model instance to check.

        Returns
        -------
        bool
            True if the model is a diffusers pipeline with a max_sequence_length parameter.
        """
        if not is_diffusers_model(model):
            return False
        # only pipelines whose __call__ exposes max_sequence_length can be pruned
        signature = inspect.signature(model.__call__)
        return "max_sequence_length" in signature.parameters

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Apply Padding Pruning to the pipeline.

        Parameters
        ----------
        model : Any
            The pipeline to apply padding pruning to.
        smash_config : SmashConfigPrefixWrapper
            Configuration settings for the pruning.

        Returns
        -------
        Any
            The pipeline with Padding Pruning enabled.
        """
        min_sequence_length = smash_config["min_sequence_length"]
        # keep a handle on the helper so it can be disabled / re-applied later
        model.padding_pruning_helper = PaddingPruningHelper(model, min_sequence_length, smash_config.tokenizer)
        model.padding_pruning_helper.enable()
        return model

    def import_algorithm_packages(self) -> dict[str, Any]:
        """
        Import necessary algorithm packages.

        Returns
        -------
        dict
            An empty dictionary as no packages are imported in this implementation.
        """
        return {}
127+
128+
129+
class PaddingPruningHelper:
    """
    Helper for Padding Pruning.

    Wraps a pipeline's ``__call__`` so that ``max_sequence_length`` is shrunk to the
    smallest power-of-two step (starting at ``min_tokens``) that still fits the longest
    prompt, clamped to the pipeline's default maximum.

    Parameters
    ----------
    pipe : Any
        The diffusers pipeline to wrap.
    min_tokens : int
        The minimum number of tokens to embed a prompt.
    tokenizer : Any
        The tokenizer of the pipeline.
    """

    def __init__(self, pipe: Any, min_tokens: int, tokenizer: Any) -> None:
        self.pipe = pipe
        self.min_tokens = min_tokens
        self.tokenizer = tokenizer
        # Original __call__ saved by wrap_pipe; None until enable() is called.
        # Initializing it here prevents an AttributeError when disable() is
        # called before enable().
        self.pipe_call: Any = None

    def enable(self) -> None:
        """Enable prompt pruning by wrapping the pipe."""
        self.wrap_pipe(self.pipe)

    def disable(self) -> None:
        """Disable prompt pruning by unwrapping the pipe."""
        if self.pipe_call is not None:
            self.pipe.__call__ = self.pipe_call

    def wrap_pipe(self, pipe: Any) -> None:
        """
        Wrap the call method of the pipe to adjust the max sequence length.

        Parameters
        ----------
        pipe : Any
            The diffusers pipeline to wrap.
        """
        # NOTE(review): assigning to the instance's __call__ only affects explicit
        # pipe.__call__(...) lookups; implicit pipe(...) resolves __call__ on the
        # type — confirm callers invoke __call__ explicitly.
        pipe_call = pipe.__call__
        self.pipe_call = pipe_call
        signature = inspect.signature(pipe_call)
        default_max_sequence_length = signature.parameters["max_sequence_length"].default

        @functools.wraps(pipe_call)
        def wrapped_call(*args, **kwargs):  # noqa: ANN201
            # while a natural approach would be to remove all padding tokens,
            # we found this to degrade the quality of the generated images
            # for this reason, we usually round to the nearest order of two
            # and use this as the max sequence length

            # the min_tokens parameter controls the minimum for the max sequence length
            min_sequence_length = self.min_tokens
            # we use the default value as the maximum value for the max sequence length
            max_sequence_length = kwargs.get("max_sequence_length", default_max_sequence_length)

            prompts = self._extract_prompts(args, kwargs)
            # default=0 guards calls that pass no textual prompt (e.g. only
            # prompt_embeds), where max() on an empty sequence would raise
            # ValueError; 0 tokens falls through to min_sequence_length below
            max_num_tokens = max((len(self.tokenizer.encode(p)) for p in prompts), default=0)

            # double from the minimum until the longest prompt fits, then clamp
            sequence_length = min_sequence_length
            while max_num_tokens > sequence_length:
                sequence_length *= 2
            if sequence_length >= max_sequence_length:
                sequence_length = max_sequence_length
            kwargs["max_sequence_length"] = sequence_length
            return pipe_call(*args, **kwargs)

        pipe.__call__ = wrapped_call

    def _extract_prompts(self, args: Any, kwargs: Any) -> list[str]:
        """Extract the prompts from the args and kwargs of the pipe call."""
        prompts: list[str] = []

        # the first arguments of diffusers pipelines are usually the prompts
        for arg in args:
            if isinstance(arg, str):
                prompts.append(arg)
            elif isinstance(arg, list):
                if len(arg) > 0 and isinstance(arg[0], str):
                    prompts.extend(arg)
            else:
                break

        for kwarg in kwargs:
            # matches prompt, prompt_2, ... but also prompt_embeds; the
            # element-type checks below filter out non-string values such as
            # embedding tensors or lists of embeddings
            if kwarg.startswith("prompt"):
                prompt = kwargs[kwarg]
                if isinstance(prompt, str):
                    prompts.append(prompt)
                elif isinstance(prompt, list) and all(isinstance(p, str) for p in prompt):
                    prompts.extend(prompt)
            if kwarg.startswith("negative_prompt"):
                negative_prompt = kwargs[kwarg]
                if isinstance(negative_prompt, str):
                    prompts.append(negative_prompt)
                elif isinstance(negative_prompt, list) and all(isinstance(p, str) for p in negative_prompt):
                    prompts.extend(negative_prompt)
        return prompts
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from pruna.algorithms.padding_pruning import PaddingPruner
2+
from pruna.engine.pruna_model import PrunaModel
3+
4+
from .base_tester import AlgorithmTesterBase
5+
6+
7+
class TestPaddingPruning(AlgorithmTesterBase):
    """Test suite for the padding pruning algorithm."""

    models = ["flux_tiny_random_with_tokenizer"]
    reject_models = ["opt_tiny_random"]
    allow_pickle_files = False
    algorithm_class = PaddingPruner
    metrics = ["cmmd"]

    def post_smash_hook(self, model: PrunaModel) -> None:
        """Hook to modify the model after smashing."""
        assert hasattr(model, "padding_pruning_helper")
        # align the text encoders with the smash-config tokenizer's vocabulary
        vocab_size = model.smash_config.tokenizer.vocab_size
        model.text_encoder.resize_token_embeddings(vocab_size)
        if hasattr(model, "text_encoder_2"):
            model.text_encoder_2.resize_token_embeddings(vocab_size)

tests/fixtures.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,15 @@ def get_diffusers_model(model_id: str, **kwargs: dict[str, Any]) -> tuple[Any, S
9191
return model, smash_config
9292

9393

94+
def get_diffusers_model_with_tokenizer(model_id: str, **kwargs: dict[str, Any]) -> tuple[Any, SmashConfig]:
    """
    Get a diffusers model whose smash config carries a dataset and a tokenizer.

    Parameters
    ----------
    model_id : str
        The Hugging Face model identifier to load.
    **kwargs : dict[str, Any]
        Extra keyword arguments forwarded to the pipeline loader.

    Returns
    -------
    tuple[Any, SmashConfig]
        The loaded pipeline and a SmashConfig with the LAION256 dataset and a
        CLIP tokenizer attached.
    """
    # reuse the plain loader but discard its config; algorithms under test
    # here need a config that provides a tokenizer and calibration data
    model, _ = get_diffusers_model(model_id, **kwargs)
    smash_config = SmashConfig()
    smash_config.add_data("LAION256")
    smash_config.add_tokenizer("openai/clip-vit-base-patch32")
    return model, smash_config
101+
102+
94103
def get_automodel_transformers(model_id: str, **kwargs: dict[str, Any]) -> tuple[Any, SmashConfig]:
95104
"""Get an AutoModelForCausalLM model for text generation."""
96105
model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
@@ -184,6 +193,9 @@ def get_autoregressive_text_to_image_model(model_id: str) -> tuple[Any, SmashCon
184193
"sd_tiny_random": partial(get_diffusers_model, "dg845/tiny-random-stable-diffusion"),
185194
"sana_tiny_random": partial(get_diffusers_model, "katuni4ka/tiny-random-sana"),
186195
"flux_tiny_random": partial(get_diffusers_model, "katuni4ka/tiny-random-flux", torch_dtype=torch.bfloat16),
196+
"flux_tiny_random_with_tokenizer": partial(
197+
get_diffusers_model_with_tokenizer, "katuni4ka/tiny-random-flux", torch_dtype=torch.float16
198+
),
187199
# text generation models
188200
"opt_tiny_random": partial(get_automodel_transformers, "yujiepan/opt-tiny-random"),
189201
"smollm_135m": partial(get_automodel_transformers, "HuggingFaceTB/SmolLM2-135M"),

0 commit comments

Comments
 (0)