From cd90a74c7e19dbda3ae4da7e811ff87f57cf153d Mon Sep 17 00:00:00 2001 From: lsb Date: Tue, 3 Dec 2024 00:43:58 -0800 Subject: [PATCH 1/3] Avoid creating a progress bar when it is disabled. This is useful when exporting a pipeline, and allows a compiler to avoid trying to compile away tqdm. --- src/diffusers/pipelines/pipeline_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a4faacb44914..c94905782040 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -102,6 +102,16 @@ logger = logging.get_logger(__name__) +class NotTQDMNoOp: + def __init__(*args, **kwargs): + return + def __enter__(self, *args, **kwargs): + return self + def __exit__(*args, **kwargs): + return + def update(*args, **kwargs): + return + @dataclass class ImagePipelineOutput(BaseOutput): @@ -1560,6 +1570,8 @@ def progress_bar(self, iterable=None, total=None): f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) + if self._progress_bar_config.get('disable', False) == True: + return NotTQDMNoOp() if iterable is not None: return tqdm(iterable, **self._progress_bar_config) elif total is not None: From 5c5b1a4b6c2cfa32a089fa672abaeb9f8c578e59 Mon Sep 17 00:00:00 2001 From: lsb Date: Tue, 3 Dec 2024 08:41:23 -0800 Subject: [PATCH 2/3] Prevent the PyTorch compiler from compiling progress bars. --- src/diffusers/pipelines/pipeline_utils.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c94905782040..8f3df5f0805c 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -102,16 +102,6 @@ logger = logging.get_logger(__name__) -class NotTQDMNoOp: - def __init__(*args, **kwargs): - return - def __enter__(self, *args, **kwargs): - return self - def __exit__(*args, **kwargs): - return - def update(*args, **kwargs): - return - @dataclass class ImagePipelineOutput(BaseOutput): @@ -1562,6 +1552,7 @@ def numpy_to_pil(images): """ return numpy_to_pil(images) + @torch.compiler.disable def progress_bar(self, iterable=None, total=None): if not hasattr(self, "_progress_bar_config"): self._progress_bar_config = {} From 874dd6b994d3ac01a403ba57640e2c614991a6fb Mon Sep 17 00:00:00 2001 From: lsb Date: Tue, 3 Dec 2024 08:42:07 -0800 Subject: [PATCH 3/3] Update pipeline_utils.py --- src/diffusers/pipelines/pipeline_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8f3df5f0805c..5a4219adcb37 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1561,8 +1561,6 @@ def progress_bar(self, iterable=None, total=None): f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) - if self._progress_bar_config.get('disable', False) == True: - return NotTQDMNoOp() if iterable is not None: return tqdm(iterable, **self._progress_bar_config) elif total is not None: