Skip to content

Commit 5c5b1a4

Browse files
authored
Prevent the PyTorch compiler from compiling progress bars.
1 parent cd90a74 commit 5c5b1a4

File tree

1 file changed

+1
-10
lines changed

1 file changed

+1
-10
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,6 @@
102102

103103
logger = logging.get_logger(__name__)
104104

105-
class NotTQDMNoOp:
106-
def __init__(*args, **kwargs):
107-
return
108-
def __enter__(self, *args, **kwargs):
109-
return self
110-
def __exit__(*args, **kwargs):
111-
return
112-
def update(*args, **kwargs):
113-
return
114-
115105

116106
@dataclass
117107
class ImagePipelineOutput(BaseOutput):
@@ -1562,6 +1552,7 @@ def numpy_to_pil(images):
15621552
"""
15631553
return numpy_to_pil(images)
15641554

1555+
@torch.compiler.disable
15651556
def progress_bar(self, iterable=None, total=None):
15661557
if not hasattr(self, "_progress_bar_config"):
15671558
self._progress_bar_config = {}

0 commit comments

Comments
 (0)