Skip to content

Commit 7a5562b

Browse files
committed
update diffusers, resolve 8bit quant memory leak #14
Pipeline components with 8-bit bitsandbytes quant layers cannot free VRAM without being garbage collected, i.e. they cannot be moved onto the CPU. Avoid caching them entirely, except for the last called pipeline, which is kept for sequential calls and destroyed when a different pipeline is called. DiffusionPipelineWrapper.__LAST_CALLED cannot hang onto the wrapper reference without interfering with memory management (modules get pinned on the GPU for 8-bit quant because they cannot be moved with .to()). Instead, create an API to recall the last used main or secondary pipeline using a factory, which may return a direct reference to it from dgenerate's object cache, or re-create it using its initial arguments. This avoids permanently pinning the self._pipeline reference used for lazy init inside DiffusionPipelineWrapper and leaking VRAM, in particular with 8-bit quant modules.
1 parent 4c3516f commit 7a5562b

File tree

7 files changed

+200
-131
lines changed

7 files changed

+200
-131
lines changed

dgenerate/batchprocess/configrunner.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,12 +317,20 @@ def _clear_object_cache(args: collections.abc.Sequence[str]):
317317
@staticmethod
318318
def _list_object_caches(args: collections.abc.Sequence[str]):
319319
"""
320-
List object cache names that may be cleared with \\clear_object_cache.
320+
List object cache names (and memory footprint if applicable) that may be cleared with \\clear_object_cache.
321321
"""
322322

323-
_messages.log('Object cache names:\n')
323+
_messages.log('Object caches:\n')
324+
324325
for object_cache in _memoize.get_object_cache_names():
325-
_messages.log(' ' * 4 + '"' + object_cache + '"')
326+
bin = _memoize.get_object_cache(object_cache)
327+
if isinstance(bin, _memory.SizedConstrainedObjectCache):
328+
_messages.log(
329+
' ' * 4 + '"' + object_cache +
330+
f'": {len(bin)} objects, cpu side RAM - {_memory.bytes_best_human_unit(bin.size)}'
331+
)
332+
else:
333+
_messages.log(' ' * 4 + '"' + object_cache + f'": {len(bin)} objects')
326334

327335
return 0
328336

dgenerate/imageprocessors/adetailer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,10 @@ def _adetailer(self, image):
341341
if self._pipe:
342342
last_pipe = self._pipe
343343
else:
344-
last_pipe = _pipelinewrapper.DiffusionPipelineWrapper.last_called_wrapper()
344+
last_pipe = _pipelinewrapper.DiffusionPipelineWrapper.recall_last_used_main_pipeline()
345345
if last_pipe is not None:
346346
# we only want the primary pipe, not the sdxl refiner for instance
347-
last_pipe = last_pipe.recall_main_pipeline().pipeline
347+
last_pipe = last_pipe.pipeline
348348

349349
if last_pipe is None:
350350
raise self.argument_error(

dgenerate/memoize.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def values(self):
118118
"""
119119
return list(self.__cache.values())
120120

121+
def __len__(self):
122+
return len(self.__cache)
123+
121124
def clear(self, collect=True):
122125
"""
123126
Clear the cache and trigger callbacks.

dgenerate/pipelinewrapper/pipelines.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1911,6 +1911,18 @@ def _enforce_torch_pipeline_cache_size(new_pipeline_size):
19111911
new_object_size=new_pipeline_size)
19121912

19131913

1914+
def _check_for_8bit_bnb_quant_uris(uris: list):
1915+
for uri in uris:
1916+
if uri is None:
1917+
continue
1918+
uri_obj = _uris.get_quantizer_uri_class(uri).parse(uri)
1919+
if isinstance(uri_obj, _uris.BNBQuantizerUri):
1920+
if uri_obj.bits == 8:
1921+
return True
1922+
1923+
return False
1924+
1925+
19141926
@_memoize(_torch_pipeline_cache,
19151927
exceptions={'local_files_only'},
19161928
hasher=_torch_args_hasher,
@@ -2623,6 +2635,17 @@ def _handle_generic_pipeline_load_failure(e):
26232635

26242636
_messages.debug_log(f'Finished Creating Torch Pipeline: "{pipeline_class.__name__}"')
26252637

2638+
# modules quantized in 8 bit by bitsandbytes cannot be moved off the GPU,
2639+
# which results in VRAM memory leaks in dgenerates caching system, just
2640+
# do not cache these pipelines for anything more than repeated calls, the
2641+
# only way they get removed from VRAM is if their reference count is zero
2642+
bnb_8bit_components = _check_for_8bit_bnb_quant_uris(
2643+
[quantizer_uri, unet_uri, transformer_uri] +
2644+
[u.quantizer for u in uri_quant_check])
2645+
2646+
if bnb_8bit_components:
2647+
_messages.debug_log(f'Pipeline has 8bit bnb components, not entering cache: "{pipeline_class.__name__}"')
2648+
26262649
# noinspection PyTypeChecker
26272650
return TorchPipelineCreationResult(
26282651
model_path=model_path,
@@ -2636,7 +2659,7 @@ def _handle_generic_pipeline_load_failure(e):
26362659
parsed_textual_inversion_uris=parsed_textual_inversion_uris,
26372660
parsed_controlnet_uris=parsed_controlnet_uris,
26382661
parsed_t2i_adapter_uris=parsed_t2i_adapter_uris
2639-
), _d_memoize.CachedObjectMetadata(size=estimated_memory_usage)
2662+
), _d_memoize.CachedObjectMetadata(size=estimated_memory_usage, skip=bnb_8bit_components)
26402663

26412664

26422665
__all__ = _types.module_all()

dgenerate/pipelinewrapper/wrapper.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,36 @@ class DiffusionPipelineWrapper:
190190
Monolithic diffusion pipelines wrapper.
191191
"""
192192

193-
__LAST_CALLED = None
193+
__LAST_RECALL_PIPELINE: _pipelines.TorchPipelineFactory = None
194+
__LAST_RECALL_SECONDARY_PIPELINE: _pipelines.TorchPipelineFactory = None
194195

195196
@staticmethod
196-
def last_called_wrapper() -> typing.Optional['DiffusionPipelineWrapper']:
197+
def recall_last_used_main_pipeline() -> typing.Optional[_pipelines.TorchPipelineCreationResult]:
197198
"""
198-
Return a reference to the last :py:class:`DiffusionPipelineWrapper`
199-
that successfully executed an image generation.
199+
Return a reference to the last :py:class:`dgenerate.pipelinewrapper.pipelines.TorchPipelineCreationResult`
200+
for the pipeline that successfully executed an image generation.
200201
201-
:return: :py:class:`DiffusionPipelineWrapper`
202+
This may recreate the pipeline if it is not cached.
203+
204+
If no image generation has occurred, this will return ``None``.
205+
206+
:return: :py:class:`dgenerate.pipelinewrapper.pipelines.TorchPipelineCreationResult` or ``None``
207+
"""
208+
return DiffusionPipelineWrapper.__LAST_RECALL_PIPELINE()
209+
210+
@staticmethod
211+
def recall_last_used_secondary_pipeline() -> typing.Optional[_pipelines.TorchPipelineCreationResult]:
212+
"""
213+
Return a reference to the last :py:class:`dgenerate.pipelinewrapper.pipelines.TorchPipelineCreationResult`
214+
for the secondary pipeline (refiner / stable cascade decoder) that successfully executed an image generation.
215+
216+
This may recreate the pipeline if it is not cached.
217+
218+
If no image generation has occurred or no secondary pipeline has been called, this will return ``None``.
219+
220+
:return: :py:class:`dgenerate.pipelinewrapper.pipelines.TorchPipelineCreationResult` or ``None``
202221
"""
203-
return DiffusionPipelineWrapper.__LAST_CALLED
222+
return DiffusionPipelineWrapper.__LAST_RECALL_SECONDARY_PIPELINE()
204223

205224
def __str__(self):
206225
return f'{self.__class__.__name__}({str(_types.get_public_attributes(self))})'
@@ -524,7 +543,7 @@ def _init(
524543
self._pipeline_type = None
525544
self._local_files_only = local_files_only
526545
self._recall_main_pipeline = None
527-
self._recall_refiner_pipeline = None
546+
self._recall_secondary_pipeline = None
528547
self._model_extra_modules = model_extra_modules
529548
self._second_model_extra_modules = second_model_extra_modules
530549
self._model_cpu_offload = model_cpu_offload
@@ -2642,23 +2661,23 @@ def recall_main_pipeline(self) -> _pipelines.PipelineCreationResult:
26422661

26432662
return self._recall_main_pipeline()
26442663

2645-
def recall_refiner_pipeline(self) -> _pipelines.PipelineCreationResult:
2664+
def recall_secondary_pipeline(self) -> _pipelines.PipelineCreationResult:
26462665
"""
2647-
Fetch the last used refiner pipeline creation result, possibly the
2648-
pipeline will be recreated if no longer in the in memory cache.
2649-
If there is no refiner pipeline currently created, which will be the
2650-
case if an image was never generated yet or a refiner model was not
2666+
Fetch the last used refiner / stable cascade decoder pipeline creation result,
2667+
possibly the pipeline will be recreated if no longer in the in memory cache.
2668+
If there is no refiner / decoder pipeline currently created, which will be the
2669+
case if an image was never generated yet or a refiner / decoder model was not
26512670
specified, :py:exc:`RuntimeError` will be raised.
26522671
26532672
:raises RuntimeError:
26542673
26552674
:return: :py:class:`dgenerate.pipelinewrapper.PipelineCreationResult`
26562675
"""
26572676

2658-
if self._recall_refiner_pipeline is None:
2677+
if self._recall_secondary_pipeline is None:
26592678
raise RuntimeError('Cannot recall refiner pipeline as one has not been created.')
26602679

2661-
return self._recall_refiner_pipeline()
2680+
return self._recall_secondary_pipeline()
26622681

26632682
def _lazy_init_pipeline(self, args: DiffusionArguments):
26642683

@@ -2703,7 +2722,7 @@ def _lazy_init_pipeline(self, args: DiffusionArguments):
27032722
self._pipeline_type = pipeline_type
27042723

27052724
self._recall_main_pipeline = None
2706-
self._recall_refiner_pipeline = None
2725+
self._recall_secondary_pipeline = None
27072726

27082727
if self._parsed_adetailer_detector_uris:
27092728
pipeline_type = _enums.PipelineType.INPAINT
@@ -2739,7 +2758,7 @@ def _lazy_init_pipeline(self, args: DiffusionArguments):
27392758
creation_result = self._recall_main_pipeline()
27402759
self._pipeline = creation_result.pipeline
27412760

2742-
self._recall_s_cascade_decoder_pipeline = _pipelines.TorchPipelineFactory(
2761+
self._recall_secondary_pipeline = _pipelines.TorchPipelineFactory(
27432762
model_path=self._parsed_s_cascade_decoder_uri.model,
27442763
model_type=_enums.ModelType.TORCH_S_CASCADE_DECODER,
27452764
pipeline_type=_enums.PipelineType.TXT2IMG,
@@ -2764,7 +2783,7 @@ def _lazy_init_pipeline(self, args: DiffusionArguments):
27642783
model_cpu_offload=self._second_model_cpu_offload,
27652784
sequential_cpu_offload=self._second_model_sequential_offload)
27662785

2767-
creation_result = self._recall_s_cascade_decoder_pipeline()
2786+
creation_result = self._recall_secondary_pipeline()
27682787
self._s_cascade_decoder_pipeline = creation_result.pipeline
27692788

27702789
elif self._sdxl_refiner_uri is not None:
@@ -2822,7 +2841,7 @@ def _lazy_init_pipeline(self, args: DiffusionArguments):
28222841
else:
28232842
refiner_extra_modules = self._second_model_extra_modules
28242843

2825-
self._recall_refiner_pipeline = _pipelines.TorchPipelineFactory(
2844+
self._recall_secondary_pipeline = _pipelines.TorchPipelineFactory(
28262845
model_path=self._parsed_sdxl_refiner_uri.model,
28272846
model_type=_enums.ModelType.TORCH_SDXL,
28282847
pipeline_type=refiner_pipeline_type,
@@ -2848,7 +2867,7 @@ def _lazy_init_pipeline(self, args: DiffusionArguments):
28482867
model_cpu_offload=self._second_model_cpu_offload,
28492868
sequential_cpu_offload=self._second_model_sequential_offload
28502869
)
2851-
self._sdxl_refiner_pipeline = self._recall_refiner_pipeline().pipeline
2870+
self._sdxl_refiner_pipeline = self._recall_secondary_pipeline().pipeline
28522871
else:
28532872
self._recall_main_pipeline = _pipelines.TorchPipelineFactory(
28542873
model_path=self._model_path,
@@ -3189,7 +3208,8 @@ def __call__(self, args: DiffusionArguments | None = None, **kwargs) -> Pipeline
31893208
result = self._call_torch(pipeline_args=pipeline_args,
31903209
user_args=copy_args)
31913210

3192-
DiffusionPipelineWrapper.__LAST_CALLED = self
3211+
DiffusionPipelineWrapper.__LAST_RECALL_PIPELINE = self._recall_main_pipeline
3212+
DiffusionPipelineWrapper.__LAST_RECALL_SECONDARY_PIPELINE = self._recall_secondary_pipeline
31933213

31943214
return result
31953215

0 commit comments

Comments
 (0)