Skip to content

Commit a4832b2

Browse files
committed
ControlNet quant with bitsandbytes and sdnq #42
Ability to quantize ControlNet, off by default when using --quantizer without --quantizer-map. The --control-nets URI now has a quantizer argument for individual quant settings. Only works for directory / hub loads currently.
1 parent 35daa36 commit a4832b2

File tree

5 files changed

+224
-15
lines changed

5 files changed

+224
-15
lines changed

dgenerate/console/schemas/submodels.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

dgenerate/pipelinewrapper/pipelines.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import diffusers.loaders
3535
import diffusers.loaders.single_file_utils
3636
import diffusers.quantizers.quantization_config
37-
import torch.nn
37+
import torch
3838
import torch.nn
3939
import transformers
4040

@@ -1226,7 +1226,8 @@ def create_diffusion_pipeline(
12261226
:param quantizer_uri: Optional ``--quantizer`` URI value
12271227
:param quantizer_map: Collection of pipeline submodule names to which quantization should be applied when
12281228
``quantizer_uri`` is provided. Valid values include: ``unet``, ``transformer``, ``text_encoder``,
1229-
``text_encoder_2``, ``text_encoder_3``. If ``None``, all supported modules will be quantized.
1229+
``text_encoder_2``, ``text_encoder_3``, and ``controlnet``. If ``None``, all supported modules will be quantized,
1230+
except for ``controlnet``.
12301231
:param pag: Use perturbed attention guidance?
12311232
:param safety_checker: Safety checker enabled? default is ``False``
12321233
:param original_config: Optional original training config .yaml file path when loading a single file checkpoint.
@@ -2217,6 +2218,17 @@ def get_device_map_for_quantizer(quantizer_uri):
22172218
if quantizer_class is _uris.SDNQQuantizerUri:
22182219
sdnq_cast_hack = True
22192220

2221+
# Check controlnet URIs
2222+
if controlnet_uris:
2223+
for controlnet_uri in controlnet_uris:
2224+
parsed_uri = _uris.ControlNetUri.parse(controlnet_uri, model_type=model_type)
2225+
uri_quant_check.append(parsed_uri)
2226+
if parsed_uri.quantizer:
2227+
manual_quantizer_components.add('controlnet')
2228+
quantizer_class = _uris.get_quantizer_uri_class(parsed_uri.quantizer)
2229+
if quantizer_class is _uris.SDNQQuantizerUri:
2230+
sdnq_cast_hack = True
2231+
22202232
if quantizer_uri or any(p.quantizer for p in uri_quant_check):
22212233
# for now, just knock out anything cached on the gpu, such as the last pipeline
22222234
# the quantized pipeline modules are likely going to go straight onto the GPU
@@ -2367,6 +2379,20 @@ def sdnq_forward(og_forward, model, *args, **kwargs):
23672379
kwargs[k] = v.to(dtype=model.dtype)
23682380
return og_forward(*args, **kwargs)
23692381

2382+
def controlnet_quant_forward(og_forward, model, *args, **kwargs):
2383+
"""
2384+
Forward function for quantized controlnets that casts inputs to the model's dtype.
2385+
This is needed because diffusers doesn't handle controlnet quantization state internally.
2386+
"""
2387+
args = list(args)
2388+
for i, arg in enumerate(args):
2389+
if isinstance(arg, torch.Tensor):
2390+
args[i] = arg.to(dtype=model.dtype)
2391+
for k, v in kwargs.items():
2392+
if isinstance(v, torch.Tensor):
2393+
kwargs[k] = v.to(dtype=model.dtype)
2394+
return og_forward(*args, **kwargs)
2395+
23702396
def load_unet(uri: _uris.UNetUri, unet_class):
23712397
unet_model = uri.load(
23722398
variant_fallback=variant,
@@ -2736,13 +2762,41 @@ def load_default_text_encoder(encoder, encoder_name):
27362762

27372763
parsed_controlnet_uris.append(parsed_controlnet_uri)
27382764

2739-
new_net = parsed_controlnet_uri.load(
2765+
# Apply global quantizer if controlnet doesn't have
2766+
# its own quantizer and should be quantized
2767+
controlnet_uri_to_load = parsed_controlnet_uri
2768+
if not parsed_controlnet_uri.quantizer and should_apply_quantizer('controlnet'):
2769+
# Create a new URI with the global quantizer
2770+
controlnet_uri_to_load = _uris.ControlNetUri(
2771+
model=parsed_controlnet_uri.model,
2772+
revision=parsed_controlnet_uri.revision,
2773+
variant=parsed_controlnet_uri.variant,
2774+
subfolder=parsed_controlnet_uri.subfolder,
2775+
dtype=parsed_controlnet_uri.dtype,
2776+
scale=parsed_controlnet_uri.scale,
2777+
start=parsed_controlnet_uri.start,
2778+
end=parsed_controlnet_uri.end,
2779+
mode=parsed_controlnet_uri.mode,
2780+
quantizer=quantizer_uri,
2781+
model_type=parsed_controlnet_uri.model_type
2782+
)
2783+
2784+
new_net = controlnet_uri_to_load.load(
27402785
use_auth_token=auth_token,
27412786
dtype_fallback=dtype,
27422787
local_files_only=local_files_only,
2743-
no_cache=model_cpu_offload or sequential_cpu_offload
2788+
no_cache=model_cpu_offload or sequential_cpu_offload,
2789+
device_map=get_device_map_for_quantizer(controlnet_uri_to_load.quantizer)
27442790
)
27452791

2792+
# Apply casting hack for quantized controlnets
2793+
if controlnet_uri_to_load.quantizer:
2794+
new_net.forward = functools.partial(
2795+
controlnet_quant_forward,
2796+
new_net.forward,
2797+
new_net
2798+
)
2799+
27462800
_messages.debug_log(lambda:
27472801
f'Added Torch ControlNet: "{controlnet_uri}" '
27482802
f'to pipeline: "{pipeline_class.__name__}"')

dgenerate/pipelinewrapper/uris/controlneturi.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from dgenerate.pipelinewrapper.uris import util as _util
3737

3838
_controlnet_uri_parser = _textprocessing.ConceptUriParser(
39-
'ControlNet', ['scale', 'start', 'end', 'mode', 'revision', 'variant', 'subfolder', 'dtype'])
39+
'ControlNet', ['scale', 'start', 'end', 'mode', 'revision', 'variant', 'subfolder', 'dtype', 'quantizer'])
4040

4141
_controlnet_cache = _d_memoize.create_object_cache(
4242
'controlnet', cache_type=_memory.SizedConstrainedObjectCache
@@ -172,6 +172,13 @@ def model_type(self) -> _enums.ModelType:
172172
"""
173173
return self._model_type
174174

175+
@property
176+
def quantizer(self) -> _types.OptionalUri:
177+
"""
178+
--quantizer URI override
179+
"""
180+
return self._quantizer
181+
175182
def __init__(self,
176183
model: str,
177184
revision: _types.OptionalString,
@@ -182,6 +189,7 @@ def __init__(self,
182189
start: float = 0.0,
183190
end: float = 1.0,
184191
mode: int | str | FluxControlNetUnionUriModes | SDXLControlNetUnionUriModes | None = None,
192+
quantizer: _types.OptionalUri = None,
185193
model_type: _enums.ModelType = _enums.ModelType.SD):
186194
"""
187195
:param model: model path
@@ -193,15 +201,24 @@ def __init__(self,
193201
:param start: controlnet guidance start value
194202
:param end: controlnet guidance end value
195203
:param mode: Flux / SDXL Union controlnet mode.
204+
:param quantizer: --quantizer URI override
196205
:param model_type: Model type this ControlNet will be attached to.
197206
198-
:raises InvalidControlNetUriError: If ``dtype`` is passed an invalid data type string.
207+
:raises InvalidControlNetUriError: If ``dtype`` is passed an invalid data type string,
208+
or if ``model`` points to a single file and ``quantizer`` is specified (not supported).
199209
"""
200210

211+
if _hfhub.is_single_file_model_load(model):
212+
if quantizer:
213+
raise _exceptions.InvalidControlNetUriError(
214+
'specifying a ControlNet quantizer URI is only supported for Hugging Face '
215+
'repository loads from a repo slug or disk path, single file loads are not supported.')
216+
201217
self._model = model
202218
self._revision = revision
203219
self._variant = variant
204220
self._subfolder = subfolder
221+
self._quantizer = quantizer
205222
self._model_type = model_type
206223

207224
if isinstance(mode, str):
@@ -232,6 +249,7 @@ def load(self,
232249
use_auth_token: _types.OptionalString = None,
233250
local_files_only: bool = False,
234251
no_cache: bool = False,
252+
device_map: str | None = None,
235253
model_class:
236254
type[diffusers.ControlNetModel] |
237255
type[diffusers.ControlNetUnionModel] |
@@ -255,6 +273,8 @@ def load(self,
255273
256274
:param no_cache: If True, force the returned object not to be cached by the memoize decorator.
257275
276+
:param device_map: device placement strategy for quantized models, defaults to ``None``
277+
258278
:param model_class: What class of controlnet model should be loaded?
259279
if ``None`` is specified, load based off :py:attr:`ControlNetUri.model_type`
260280
and provided URI arguments.
@@ -283,6 +303,7 @@ def cache_all(e):
283303
use_auth_token,
284304
local_files_only,
285305
no_cache,
306+
device_map,
286307
model_class)
287308

288309

@@ -305,6 +326,7 @@ def _load(self,
305326
use_auth_token: _types.OptionalString = None,
306327
local_files_only: bool = False,
307328
no_cache: bool = False,
329+
device_map: str | None = None,
308330
model_class:
309331
type[diffusers.ControlNetModel] |
310332
type[diffusers.ControlNetUnionModel] |
@@ -329,6 +351,14 @@ def _load(self,
329351
torch_dtype = _enums.get_torch_dtype(
330352
dtype_fallback if self.dtype is None else self.dtype)
331353

354+
if self.quantizer:
355+
quant_config = _util.get_quantizer_uri_class(
356+
self.quantizer,
357+
_exceptions.InvalidControlNetUriError
358+
).parse(self.quantizer).to_config(torch_dtype)
359+
else:
360+
quant_config = None
361+
332362
if single_file_load_path:
333363

334364
estimated_memory_usage = _pipelinewrapper_util.estimate_model_memory_use(
@@ -366,7 +396,9 @@ def _load(self,
366396
subfolder=self.subfolder,
367397
torch_dtype=torch_dtype,
368398
token=use_auth_token,
369-
local_files_only=local_files_only)
399+
local_files_only=local_files_only,
400+
quantization_config=quant_config,
401+
device_map=device_map)
370402

371403
_messages.debug_log('Estimated Torch ControlNet Memory Use:',
372404
_memory.bytes_best_human_unit(estimated_memory_usage))
@@ -376,7 +408,7 @@ def _load(self,
376408
# noinspection PyTypeChecker
377409
return new_net, _d_memoize.CachedObjectMetadata(
378410
size=estimated_memory_usage,
379-
skip=no_cache
411+
skip=self.quantizer or no_cache
380412
)
381413

382414
@staticmethod
@@ -450,6 +482,7 @@ def parse(uri: _types.Uri,
450482
start=start,
451483
end=end,
452484
mode=mode,
485+
quantizer=r.args.get('quantizer', None),
453486
model_type=model_type
454487
)
455488

docs/manual.rst

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9677,7 +9677,8 @@ Quantization
96779677
============
96789678

96799679
Quantization via ``bitsandbytes`` and ``sdnq`` is supported for certain
9680-
diffusion submodels, for instance, the unet/transformer, and all text encoders.
9680+
diffusion submodels, for instance, the unet/transformer, all text encoders,
9681+
and controlnet models.
96819682

96829683
It is also supported for certain plugins which utilize LLMs, such as the
96839684
``magicprompt`` upscaler, and ``llm4gen`` prompt weighter.
@@ -9700,7 +9701,8 @@ diffusion pipeline as it loads.
97009701

97019702
You can control which sub modules of the diffusion pipeline get quantized
97029703
by using the ``--quantizer-map`` argument, which accepts a list
9703-
of ``diffusers`` module names, e.g.
9704+
of ``diffusers`` module names, e.g. ``unet``, ``text_encoder``, ``text_encoder_2``,
9705+
``transformer``, ``controlnet``, etc.
97049706

97059707
.. code-block:: bash
97069708

@@ -9720,7 +9722,7 @@ of ``diffusers`` module names, e.g.
97209722

97219723

97229724
Quantization URIs can also be supplied via sub-model URIs; the arguments
9723-
``--unet``, ``--transformer``, and ``--text-encoders`` all support a ``quantizer``
9725+
``--unet``, ``--transformer``, ``--text-encoders``, and ``--control-nets`` all support a ``quantizer``
97249726
sub URI argument for specifying the quantization backend for that particular sub-model.
97259727

97269728
This allows you to set specific quantization settings for sub-models individually.
@@ -9743,6 +9745,65 @@ dgenerate as a URI argument separator.
97439745
--prompts "a cute cat"
97449746

97459747

9748+
ControlNet Quantization
9749+
-----------------------
9750+
ControlNet models are **NOT** quantized by default when using the global ``--quantizer``
9751+
argument. To quantize ControlNets, you must either:
9752+
9753+
1. Add ``controlnet`` to the ``--quantizer-map`` list to apply global quantization
9754+
2. Specify individual quantization settings per ControlNet using the ``quantizer`` URI argument
9755+
9756+
.. code-block:: bash
9757+
9758+
#!/usr/bin/env bash
9759+
9760+
# Method 1: Global quantization with controlnet in quantizer-map
9761+
9762+
dgenerate stabilityai/stable-diffusion-xl-base-1.0 \
9763+
--model-type sdxl \
9764+
--dtype float16 \
9765+
--variant fp16 \
9766+
--quantizer "bnb;bits=8" \
9767+
--quantizer-map unet text_encoder text_encoder_2 controlnet \
9768+
--control-nets "diffusers/controlnet-canny-sdxl-1.0" \
9769+
--inference-steps 30 \
9770+
--guidance-scales 5 \
9771+
--prompts "a cute cat"
9772+
9773+
.. code-block:: bash
9774+
9775+
#!/usr/bin/env bash
9776+
9777+
# Method 2: Individual ControlNet quantization
9778+
9779+
dgenerate stabilityai/stable-diffusion-xl-base-1.0 \
9780+
--model-type sdxl \
9781+
--dtype float16 \
9782+
--variant fp16 \
9783+
--control-nets 'diffusers/controlnet-canny-sdxl-1.0;quantizer="bnb;bits=4"' \
9784+
--inference-steps 30 \
9785+
--guidance-scales 5 \
9786+
--prompts "a cute cat"
9787+
9788+
.. code-block:: bash
9789+
9790+
#!/usr/bin/env bash
9791+
9792+
# ControlNet NOT quantized, only unet and text encoders
9793+
9794+
dgenerate stabilityai/stable-diffusion-xl-base-1.0 \
9795+
--model-type sdxl \
9796+
--dtype float16 \
9797+
--variant fp16 \
9798+
--quantizer "bnb;bits=8" \
9799+
--control-nets "diffusers/controlnet-canny-sdxl-1.0" \
9800+
--inference-steps 30 \
9801+
--guidance-scales 5 \
9802+
--prompts "a cute cat"
9803+
9804+
ControlNet quantization is only supported for Hugging Face repository loads
9805+
and local directory paths. Single file ControlNet loads do not support quantization.
9806+
97469807
Quantizer usage documentation can be obtained with ``--quantizer-help`` or the
97479808
equivalent ``\quantizer_help`` config directive, you can use this argument or
97489809
directive to list quantization backend names, when you supply backend names as

0 commit comments

Comments
 (0)