# limitations under the License.

from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple, Type, TypeVar
+from typing import Callable, List, Optional, Tuple

import torch.nn as nn


_ATTENTION_CLASSES = (Attention,)

-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ["blocks", "transformer_blocks"]
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ["temporal_transformer_blocks"]
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ["blocks", "transformer_blocks"]
+_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")


@dataclass
class PyramidAttentionBroadcastConfig:
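+    r"""
+    Configuration for Pyramid Attention Broadcast (PAB).
+
+    Each `*_attention_block_skip_range` sets how many inference iterations elapse between full attention
+    computations for that attention type; `None` disables PAB for it. Each `*_attention_timestep_skip_range`
+    is the `(min, max)` denoising-timestep window inside which skipping is applied, and each
+    `*_attention_block_identifiers` tuple lists the name fragments used to match attention layers in the denoiser.
+    """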
-    spatial_attention_block_skip = None
-    temporal_attention_block_skip = None
-    cross_attention_block_skip = None
-
-    spatial_attention_timestep_skip_range = (100, 800)
-    temporal_attention_timestep_skip_range = (100, 800)
-    cross_attention_timestep_skip_range = (100, 800)
+    spatial_attention_block_skip_range: Optional[int] = None
+    temporal_attention_block_skip_range: Optional[int] = None
+    cross_attention_block_skip_range: Optional[int] = None

-    spatial_attention_block_identifiers = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+    temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+    cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS


class PyramidAttentionBroadcastState:
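+    r"""
+    Mutable per-module state for Pyramid Attention Broadcast: `iteration` counts the attention calls made inside
+    the timestep skip window, and `reset_state` sets the counter back to zero.
+    """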
-    iteration = 0
+    def __init__(self) -> None:
+        self.iteration = 0
+
+    def reset_state(self):
+        self.iteration = 0


def apply_pyramid_attention_broadcast(
@@ -59,56 +63,105 @@ def apply_pyramid_attention_broadcast(
):
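+    r"""
+    Apply Pyramid Attention Broadcast (PAB) hooks to the attention layers of a pipeline's denoiser.
+
+    If no `denoiser` is passed, `pipeline.transformer` (or `pipeline.unet` as a fallback) is used; if no `config`
+    is passed, a default `PyramidAttentionBroadcastConfig` is created.
+
+    Example (a minimal sketch, assuming `pipeline` is a loaded `DiffusionPipeline` and that `config` is accepted
+    as a keyword argument):
+
+        config = PyramidAttentionBroadcastConfig(spatial_attention_block_skip_range=2)
+        apply_pyramid_attention_broadcast(pipeline, config=config)
+    """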
    if config is None:
        config = PyramidAttentionBroadcastConfig()
-
-    if config.spatial_attention_block_skip is None and config.temporal_attention_block_skip is None and config.cross_attention_block_skip is None:
+
+    if (
+        config.spatial_attention_block_skip_range is None
+        and config.temporal_attention_block_skip_range is None
+        and config.cross_attention_block_skip_range is None
+    ):
        logger.warning(
-            "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip`, `temporal_attention_block_skip` "
-            "or `cross_attention_block_skip` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip=2`. "
+            "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
+            "or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to `spatial_attention_block_skip_range=2`. "
            "To avoid this warning, please set one of the above parameters."
        )
-        config.spatial_attention_block_skip = 2
-
+        config.spatial_attention_block_skip_range = 2
+
    if denoiser is None:
        denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
-
+
    for name, module in denoiser.named_modules():
        if not isinstance(module, _ATTENTION_CLASSES):
            continue
        if isinstance(module, Attention):
            _apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)


-# def apply_pyramid_attention_broadcast_spatial(module: TypeVar[_ATTENTION_CLASSES], config: PyramidAttentionBroadcastConfig):
-#     hook = PyramidAttentionBroadcastHook(skip_callback=)
-#     add_hook_to_module(module)
+def apply_pyramid_attention_broadcast_on_module(
+    module: Attention,
+    block_skip_range: int,
+    timestep_skip_range: Tuple[int, int],
+    current_timestep_callback: Callable[[], int],
+):
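+    r"""
+    Register a Pyramid Attention Broadcast hook on a single attention module.
+
+    `block_skip_range` is the number of iterations between full attention computations inside the skip window,
+    `timestep_skip_range` is the `(min, max)` timestep window in which skipping is allowed, and
+    `current_timestep_callback` must return the current denoising timestep when called.
+    """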
+    module._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState()
+    min_timestep, max_timestep = timestep_skip_range
+
+    def skip_callback(attention_module: nn.Module) -> bool:
+        pab_state: PyramidAttentionBroadcastState = attention_module._pyramid_attention_broadcast_state
+        current_timestep = current_timestep_callback()
+        is_within_timestep_range = min_timestep < current_timestep < max_timestep
+
+        if is_within_timestep_range:
+            # As soon as the current timestep is within the timestep range, we start skipping attention computation.
+            # The following inference steps will compute the attention every `block_skip_range` steps.
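+            # For example, with `block_skip_range=2` the attention is recomputed on iterations 2, 4, 6, ... and
+            # skipped on every other iteration while the timestep stays within the skip window.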
+            should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
+            pab_state.iteration += 1
+            return not should_compute_attention
+
+        # We are not yet in the phase of inference where attention can be skipped with minimal quality loss, as
+        # described in the paper, so the attention computation cannot be skipped.
+        return False

+    hook = PyramidAttentionBroadcastHook(skip_callback=skip_callback)
+    add_hook_to_module(module, hook, append=True)

-def _apply_pyramid_attention_broadcast_on_attention_class(pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig):
+
+def _apply_pyramid_attention_broadcast_on_attention_class(
+    pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig
+):
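+    r"""
+    Decide, from the layer name and the config, whether this attention module is a spatial self-attention,
+    temporal self-attention, or cross-attention block, and register the corresponding PAB hook on it.
+    """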
    # Similar check as PEFT to determine if a string layer name matches a module name
    is_spatial_self_attention = (
-        any(f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers)
-        and config.spatial_attention_timestep_skip_range is not None
+        any(
+            f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers
+        )
+        and config.spatial_attention_block_skip_range is not None
        and not module.is_cross_attention
    )
    is_temporal_self_attention = (
-        any(f"{identifier}." in name or identifier == name for identifier in config.temporal_attention_block_identifiers)
-        and config.temporal_attention_timestep_skip_range is not None
+        any(
+            f"{identifier}." in name or identifier == name
+            for identifier in config.temporal_attention_block_identifiers
+        )
+        and config.temporal_attention_block_skip_range is not None
        and not module.is_cross_attention
    )
    is_cross_attention = (
        any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers)
-        and config.cross_attention_timestep_skip_range is not None
+        and config.cross_attention_block_skip_range is not None
-        and not module.is_cross_attention
+        and module.is_cross_attention
    )

+    block_skip_range, timestep_skip_range = None, None
    if is_spatial_self_attention:
-        apply_pyramid_attention_broadcast_spatial(module, config)
+        block_skip_range = config.spatial_attention_block_skip_range
+        timestep_skip_range = config.spatial_attention_timestep_skip_range
    elif is_temporal_self_attention:
-        apply_pyramid_attention_broadcast_temporal(module, config)
+        block_skip_range = config.temporal_attention_block_skip_range
+        timestep_skip_range = config.temporal_attention_timestep_skip_range
    elif is_cross_attention:
-        apply_pyramid_attention_broadcast_cross(module, config)
-    else:
+        block_skip_range = config.cross_attention_block_skip_range
+        timestep_skip_range = config.cross_attention_timestep_skip_range
+
+    if block_skip_range is None or timestep_skip_range is None:
        logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.")
+        return
+
+    def current_timestep_callback():
+        return pipeline._current_timestep
+
+    apply_pyramid_attention_broadcast_on_module(
+        module, block_skip_range, timestep_skip_range, current_timestep_callback
+    )


class PyramidAttentionBroadcastMixin: