Skip to content

Commit 3f7776d

Browse files
committed
Support a granular quantization `mapping` in `PipelineQuantizationConfig`.
1 parent 30b1ef2 commit 3f7776d

File tree

3 files changed

+86
-18
lines changed

3 files changed

+86
-18
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -769,9 +769,10 @@ def load_sub_model(
769769
and isinstance(quantization_config, PipelineQuantizationConfig)
770770
and issubclass(class_obj, torch.nn.Module)
771771
):
772-
exclude_modules = quantization_config.exclude_modules or []
773-
if name not in exclude_modules:
774-
model_quant_config = _resolve_quant_config(quantization_config, is_diffusers=is_diffusers_model)
772+
model_quant_config = quantization_config._resolve_quant_config(
773+
is_diffusers=is_diffusers_model, module_name=name
774+
)
775+
if model_quant_config is not None:
775776
loading_kwargs["quantization_config"] = model_quant_config
776777

777778
# check if the module is in a subdirectory
@@ -1085,20 +1086,33 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
10851086
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
10861087

10871088

1088-
def _resolve_quant_config(quant_config, is_diffusers=True, module_name=None):
    """Resolve `quant_config` into an instantiated backend-specific quantization config.

    Args:
        quant_config: A `PipelineQuantizationConfig`-like object. Either a global
            config exposing `quant_backend`/`quant_kwargs`, or a granular one
            exposing `is_granular=True` and a `mapping` of module name -> dict
            with `"quant_backend"` and optional `"quant_kwargs"` keys.
        is_diffusers (bool): Resolve against the diffusers backend registry when
            True, otherwise against the transformers registry.
        module_name (str, optional): Name of the pipeline component being loaded;
            only consulted in the granular case.

    Returns:
        An instantiated quantization config, or `None` when `module_name` has no
        entry in a granular `mapping` (i.e. the module should stay unquantized).

    Raises:
        ValueError: If the requested `quant_backend` is not a registered backend.
    """
    if is_diffusers:
        from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING
    else:
        from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING

    def _instantiate(quant_backend, quant_kwargs, err_prefix=""):
        # Shared validation + instantiation for both the granular and global paths.
        if quant_backend not in AUTO_QUANTIZATION_CONFIG_MAPPING:
            raise ValueError(
                f"{err_prefix}Provided quant_backend={quant_backend} was not found. "
                f"Available ones are: {list(AUTO_QUANTIZATION_CONFIG_MAPPING.keys())}."
            )
        # `quant_kwargs` may legitimately be absent (None) in a mapping entry.
        return AUTO_QUANTIZATION_CONFIG_MAPPING[quant_backend](**(quant_kwargs or {}))

    # Granular case: per-module configs keyed by module name.
    if getattr(quant_config, "is_granular", False):
        config = quant_config.mapping.get(module_name)
        if config is None:
            # Module not listed in the mapping: leave it unquantized instead of
            # crashing with an AttributeError on `None.get(...)`.
            return None
        return _instantiate(
            config.get("quant_backend"), config.get("quant_kwargs"), err_prefix=f"Module '{module_name}': "
        )

    # Global config case: one backend/kwargs pair for the whole pipeline.
    return _instantiate(quant_config.quant_backend, quant_config.quant_kwargs)

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
875875
}
876876
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
877877

878+
# TODO: validate that the modules specified in the quantization_config `mapping` actually exist in the pipeline.
879+
#########################
880+
878881
# remove `null` components
879882
def load_module(name, value):
880883
if value[0] is None:

src/diffusers/quantizers/__init__.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,63 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from ..utils import is_transformers_available
1516
from .auto import DiffusersAutoQuantizer
1617
from .base import DiffusersQuantizer
1718

1819

1920
class PipelineQuantizationConfig:
    """Pipeline-level quantization configuration.

    Supports two mutually exclusive modes:

    * **Global**: a single ``quant_backend`` (with optional ``quant_kwargs``)
      applied to every model-bearing component except those named in
      ``exclude_modules``.
    * **Granular**: a ``mapping`` from component name to a dict with keys
      ``"quant_backend"`` and (optionally) ``"quant_kwargs"``; components
      absent from the mapping are left unquantized.
    """

    def __init__(
        self,
        quant_backend: str = None,
        quant_kwargs: dict = None,
        exclude_modules: list = None,
        mapping: dict = None,
    ):
        if mapping is not None:
            # Granular mode: per-module configs; the global fields are unused.
            self.mapping = mapping
            self.is_granular = True
        else:
            # Global mode: one backend applied to every module not excluded.
            self.quant_backend = quant_backend
            self.quant_kwargs = quant_kwargs or {}
            self.exclude_modules = exclude_modules or []
            self.is_granular = False

        self.post_init()

    def post_init(self):
        """Validate that the chosen mode has the information it needs.

        Raises:
            ValueError: If granular mode was requested with an empty or
                non-dict ``mapping``, or global mode lacks a ``quant_backend``.
        """
        if self.is_granular:
            # NOTE(review): the previous check (`self.mapping is None`) could
            # never fire, since `is_granular` is only set when a mapping was
            # provided; validate the mapping's content instead.
            if not isinstance(self.mapping, dict) or not self.mapping:
                raise ValueError(
                    "In the granular case, a `mapping` defining the quantization configs"
                    " for the desired modules have to be defined."
                )
        elif self.quant_backend is None:
            raise ValueError("Either a `mapping` or a `quant_backend` must be provided.")

    def _resolve_quant_config(self, is_diffusers=True, module_name=None):
        """Return an instantiated quantization config for ``module_name``.

        Args:
            is_diffusers (bool): Resolve against the diffusers backend registry
                when True, otherwise against the transformers registry.
            module_name (str, optional): Name of the pipeline component.

        Returns:
            The instantiated backend-specific config, or ``None`` when the
            module should not be quantized (excluded in global mode, or absent
            from a granular ``mapping``).

        Raises:
            ValueError: If the requested backend is unknown, or transformers is
                required but unavailable.
        """
        if is_diffusers:
            from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING
        else:
            if not is_transformers_available():
                # Previously this fell through to a NameError on the registry
                # below; fail with an actionable message instead.
                raise ValueError(
                    "Resolving a transformers quantization config requires `transformers` to be installed."
                )
            from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING

        def _instantiate(quant_backend, quant_kwargs, err_prefix=""):
            # Shared validation + instantiation for both modes.
            if quant_backend not in AUTO_QUANTIZATION_CONFIG_MAPPING:
                raise ValueError(
                    f"{err_prefix}Provided quant_backend={quant_backend} was not found. "
                    f"Available ones are: {list(AUTO_QUANTIZATION_CONFIG_MAPPING.keys())}."
                )
            # A mapping entry may omit `quant_kwargs`; default to no kwargs.
            return AUTO_QUANTIZATION_CONFIG_MAPPING[quant_backend](**(quant_kwargs or {}))

        # Granular case.
        if self.is_granular:
            config = self.mapping.get(module_name)
            if config is None:
                # Module not listed in the mapping: leave it unquantized rather
                # than raising an AttributeError on `None.get(...)`.
                return None
            return _instantiate(
                config.get("quant_backend"), config.get("quant_kwargs"), err_prefix=f"Module '{module_name}': "
            )

        # Global config case.
        if module_name not in self.exclude_modules:
            return _instantiate(self.quant_backend, self.quant_kwargs)

        # Explicitly excluded module: the caller skips quantization.
        return None

0 commit comments

Comments
 (0)