
Commit d98473d

refactor

1 parent 6de34fe commit d98473d

File tree

8 files changed: +238 -233 lines changed


src/diffusers/__init__.py

Lines changed: 9 additions & 5 deletions
@@ -78,8 +78,10 @@
 else:
     _import_structure["hooks"].extend(
         [
+            "FasterCacheConfig",
             "HookRegistry",
             "PyramidAttentionBroadcastConfig",
+            "apply_faster_cache",
             "apply_pyramid_attention_broadcast",
         ]
     )
@@ -287,7 +289,6 @@
             "CogView3PlusPipeline",
             "ConsisIDPipeline",
             "CycleDiffusionPipeline",
-            "FasterCacheConfig",
             "FluxControlImg2ImgPipeline",
             "FluxControlInpaintPipeline",
             "FluxControlNetImg2ImgPipeline",
@@ -434,7 +435,6 @@
             "WuerstchenCombinedPipeline",
             "WuerstchenDecoderPipeline",
             "WuerstchenPriorPipeline",
-            "apply_fastercache",
         ]
     )
 
@@ -599,7 +599,13 @@
     except OptionalDependencyNotAvailable:
         from .utils.dummy_pt_objects import *  # noqa F403
     else:
-        from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+        from .hooks import (
+            FasterCacheConfig,
+            HookRegistry,
+            PyramidAttentionBroadcastConfig,
+            apply_faster_cache,
+            apply_pyramid_attention_broadcast,
+        )
         from .models import (
             AllegroTransformer3DModel,
             AsymmetricAutoencoderKL,
@@ -782,7 +788,6 @@
             CogView3PlusPipeline,
             ConsisIDPipeline,
             CycleDiffusionPipeline,
-            FasterCacheConfig,
             FluxControlImg2ImgPipeline,
             FluxControlInpaintPipeline,
             FluxControlNetImg2ImgPipeline,
@@ -927,7 +932,6 @@
             WuerstchenCombinedPipeline,
             WuerstchenDecoderPipeline,
             WuerstchenPriorPipeline,
-            apply_fastercache,
         )
 
     try:
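
Net effect of this file: FasterCacheConfig and apply_faster_cache are now exported through the "hooks" entry of _import_structure rather than the "pipelines" entry, and the old apply_fastercache spelling is gone. A minimal sketch of the resulting import path (hypothetical caller code, not part of the diff):

    # Top-level import after this commit; both names resolve via diffusers.hooks.
    from diffusers import FasterCacheConfig, apply_faster_cache

    # The pre-commit spelling was removed and would now fail:
    # from diffusers import apply_fastercache  # ImportError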

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 
 if is_torch_available():
+    from .faster_cache import FasterCacheConfig, apply_faster_cache
     from .hooks import HookRegistry, ModelHook
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
     from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
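
Registering faster_cache here is what makes the top-level re-export above work. A minimal usage sketch following the calls exercised in the updated tests further down; the CogVideoX checkpoint is an illustrative assumption, not something this diff prescribes:

    import torch
    from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

    # Illustrative pipeline; any FasterCache-supported pipeline follows the same pattern.
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

    # Reuse cached spatial attention states instead of recomputing them at every step.
    config = FasterCacheConfig(spatial_attention_block_skip_range=2)
    apply_faster_cache(pipe, config)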

src/diffusers/pipelines/fastercache_utils.py renamed to src/diffusers/hooks/faster_cache.py

Lines changed: 202 additions & 200 deletions
Large diffs are not rendered by default.

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:
 
     def __repr__(self) -> str:
         return (
-            f"PyramidAttentionBroadcastConfig("
+            f"PyramidAttentionBroadcastConfig(\n"
             f"  spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
             f"  temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
             f"  cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"

src/diffusers/pipelines/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -58,7 +58,6 @@
         "StableDiffusionMixin",
         "ImagePipelineOutput",
     ]
-    _import_structure["faster_cache_utils"] = ["FasterCacheConfig", "apply_fastercache"]
     _import_structure["deprecated"].extend(
         [
             "PNDMPipeline",
@@ -451,7 +450,6 @@
     from .ddpm import DDPMPipeline
     from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
     from .dit import DiTPipeline
-    from .fastercache_utils import FasterCacheConfig, apply_fastercache
     from .latent_diffusion import LDMSuperResolutionPipeline
     from .pipeline_utils import (
         AudioPipelineOutput,

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 19 additions & 0 deletions
@@ -2,6 +2,21 @@
 from ..utils import DummyObject, requires_backends
 
 
+class FasterCacheConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class HookRegistry(metaclass=DummyObject):
     _backends = ["torch"]
 
@@ -32,6 +47,10 @@ def from_pretrained(cls, *args, **kwargs):
     requires_backends(cls, ["torch"])
 
 
+def apply_faster_cache(*args, **kwargs):
+    requires_backends(apply_faster_cache, ["torch"])
+
+
 def apply_pyramid_attention_broadcast(*args, **kwargs):
     requires_backends(apply_pyramid_attention_broadcast, ["torch"])
 
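
These dummies pair with the `from .utils.dummy_pt_objects import *` fallback in src/diffusers/__init__.py above: the names stay importable without torch, and the failure is deferred to first use. A sketch of the behavior, assuming an environment where torch is not installed:

    from diffusers import FasterCacheConfig  # succeeds; resolves to the dummy class

    config = FasterCacheConfig()  # raises ImportError via requires_backends,
                                  # naming the missing "torch" backend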

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 0 additions & 19 deletions
@@ -392,21 +392,6 @@ def from_pretrained(cls, *args, **kwargs):
     requires_backends(cls, ["torch", "transformers"])
 
 
-class FasterCacheConfig(metaclass=DummyObject):
-    _backends = ["torch", "transformers"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
-
 class FluxControlImg2ImgPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
@@ -2565,7 +2550,3 @@ def from_config(cls, *args, **kwargs):
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
-
-
-def apply_fastercache(*args, **kwargs):
-    requires_backends(apply_fastercache, ["torch", "transformers"])

tests/pipelines/test_pipelines_common.py

Lines changed: 6 additions & 6 deletions
@@ -29,7 +29,7 @@
     StableDiffusionPipeline,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
-    apply_fastercache,
+    apply_faster_cache,
 )
 from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
 from diffusers.image_processor import VaeImageProcessor
@@ -2479,21 +2479,21 @@ def test_fastercache_basic_warning_or_errors_raised(self):
         # Check if warning is raised when no FasterCacheConfig is provided
         pipe = self.pipeline_class(**components)
         with CaptureLogger(logger) as cap_logger:
-            apply_fastercache(pipe)
+            apply_faster_cache(pipe)
         self.assertTrue("No FasterCacheConfig provided" in cap_logger.out)
 
         # Check if warning is raised when no attention_weight_callback is provided
         pipe = self.pipeline_class(**components)
         with CaptureLogger(logger) as cap_logger:
             config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None)
-            apply_fastercache(pipe, config)
+            apply_faster_cache(pipe, config)
         self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out)
 
         # Check if error is raised when an unsupported tensor format is used
         pipe = self.pipeline_class(**components)
         with self.assertRaises(ValueError):
             config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC")
-            apply_fastercache(pipe, config)
+            apply_faster_cache(pipe, config)
 
     def test_fastercache_inference(self, expected_atol: float = 0.1):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -2509,7 +2509,7 @@ def test_fastercache_inference(self, expected_atol: float = 0.1):
         original_image_slice = output.flatten()
         original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
 
-        apply_fastercache(pipe, self.fastercache_config)
+        apply_faster_cache(pipe, self.fastercache_config)
 
         inputs = self.get_dummy_inputs(device)
         inputs["num_inference_steps"] = 4
@@ -2541,7 +2541,7 @@ def test_fastercache_state(self):
         pipe = self.pipeline_class(**components)
         pipe.set_progress_bar_config(disable=None)
 
-        apply_fastercache(pipe, self.fastercache_config)
+        apply_faster_cache(pipe, self.fastercache_config)
 
         expected_hooks = 0
         if self.fastercache_config.spatial_attention_block_skip_range is not None:
