
from ...configuration_utils import ConfigMixin, register_to_config
from ..activations import get_activation
+ from ..attention_processor import MultiscaleLinearAttention
from ..modeling_utils import ModelMixin
- from ..normalization import RMSNorm
-
-
- def get_norm_layer(name: Optional[str] = "batch_norm", num_features: Optional[int] = None) -> Optional[nn.Module]:
-     if name is None:
-         norm = None
-     elif name == "rms_norm":
-         norm = RMSNorm(num_features, eps=1e-5, elementwise_affine=True, bias=True)
-     elif name == "batch_norm":
-         norm = nn.BatchNorm2d(num_features=num_features)
-     else:
-         raise ValueError(f"norm {name} is not supported")
-     return norm
+ from ..normalization import RMSNorm, get_normalization


class GLUMBConv(nn.Module):
@@ -81,7 +70,7 @@ def __init__(
        self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
-         self.norm = get_norm_layer(norm_type, out_channels)
+         self.norm = get_normalization(norm_type, out_channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
@@ -93,149 +82,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
        else:
            hidden_states = self.norm(hidden_states)
-         return hidden_states + residual
-
-
- class MLAProjection(nn.Module):
-     def __init__(
-         self,
-         in_channels: int,
-         num_attention_heads: int,
-         kernel_size: int,
-     ) -> None:
-         super().__init__()
-
-         self.proj_in = nn.Conv2d(
-             3 * in_channels,
-             3 * in_channels,
-             kernel_size,
-             padding=kernel_size // 2,
-             groups=3 * in_channels,
-             bias=False,
-         )
-         self.proj_out = nn.Conv2d(
-             3 * in_channels, 3 * in_channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False
-         )
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         hidden_states = self.proj_in(hidden_states)
-         hidden_states = self.proj_out(hidden_states)
-         return hidden_states
-
-
- class LiteMLA(nn.Module):
-     r"""Lightweight multi-scale linear attention"""
-
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         num_attention_heads: Optional[int] = None,
-         heads_ratio: float = 1.0,
-         attention_head_dim: int = 8,
-         norm_type: str = "batch_norm",
-         kernel_sizes: Tuple[int, ...] = (5,),
-         eps: float = 1e-15,
-     ):
-         super().__init__()
-
-         self.eps = eps
-         self.attention_head_dim = attention_head_dim
-         self.norm_type = norm_type
-
-         num_attention_heads = (
-             int(in_channels // attention_head_dim * heads_ratio)
-             if num_attention_heads is None
-             else num_attention_heads
-         )
-         inner_dim = num_attention_heads * attention_head_dim
-
-         self.to_qkv = nn.Conv2d(in_channels, 3 * inner_dim, 1, 1, 0, bias=False)
-
-         self.to_qkv_multiscale = nn.ModuleList()
-         for kernel_size in kernel_sizes:
-             self.to_qkv_multiscale.append(MLAProjection(inner_dim, num_attention_heads, kernel_size))
-
-         self.kernel_nonlinearity = nn.ReLU()
-         self.proj_out = nn.Conv2d(inner_dim * (1 + len(kernel_sizes)), out_channels, 1, 1, 0, bias=False)
-         self.norm_out = get_norm_layer(norm_type, num_features=out_channels)
-
-     def linear_attention(self, qkv: torch.Tensor) -> torch.Tensor:
-         batch_size, _, height, width = qkv.shape
-
-         qkv = qkv.float()
-         qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
-
-         query, key, value = (
-             qkv[:, :, 0 : self.attention_head_dim],
-             qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
-             qkv[:, :, 2 * self.attention_head_dim :],
-         )
-
-         # lightweight linear attention
-         query = self.kernel_nonlinearity(query)
-         key = self.kernel_nonlinearity(key)
-         value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
-
-         key_T = key.transpose(-1, -2)
-         scores = torch.matmul(value, key_T)
-         output = torch.matmul(scores, query)
-
-         output = output.float()
-         output = output[:, :, :-1] / (output[:, :, -1:] + self.eps)
-         output = torch.reshape(output, (batch_size, -1, height, width))
-
-         return output
-
-     def quadratic_attention(self, qkv: torch.Tensor) -> torch.Tensor:
-         batch_size, _, height, width = list(qkv.size())
-
-         qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
-         query, key, value = (
-             qkv[:, :, 0 : self.attention_head_dim],
-             qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
-             qkv[:, :, 2 * self.attention_head_dim :],
-         )
-
-         query = self.kernel_nonlinearity(query)
-         key = self.kernel_nonlinearity(key)
-
-         scores = torch.matmul(key.transpose(-1, -2), query)
-
-         original_dtype = scores.dtype
-         scores = scores.float()
-         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-         scores = scores.to(original_dtype)
-
-         output = torch.matmul(value, scores)
-         output = torch.reshape(output, (batch_size, -1, height, width))
-
-         return output
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         residual = hidden_states
-
-         qkv = self.to_qkv(hidden_states)
-
-         multi_scale_qkv = [qkv]
-         for block in self.to_qkv_multiscale:
-             multi_scale_qkv.append(block(qkv))
-
-         qkv = torch.cat(multi_scale_qkv, dim=1)
-
-         height, width = qkv.shape[-2:]
-         if height * width > self.attention_head_dim:
-             hidden_states = self.linear_attention(qkv).to(qkv.dtype)
-         else:
-             hidden_states = self.quadratic_attention(qkv)
-
-         hidden_states = self.proj_out(hidden_states)
-
-         if self.norm_type == "rms_norm":
-             hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
-         else:
-             hidden_states = self.norm_out(hidden_states)
-
+
        return hidden_states + residual

@@ -247,10 +94,10 @@ def __init__(
        dim: int = 32,
        qkv_multiscales: Tuple[int, ...] = (5,),
        norm_type: str = "batch_norm",
-     ):
+     ) -> None:
        super().__init__()

-         self.attn = LiteMLA(
+         self.attn = MultiscaleLinearAttention(
            in_channels=in_channels,
            out_channels=in_channels,
            heads_ratio=heads_ratio,
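
The hunks above fold the file-local `LiteMLA` / `MLAProjection` classes into the shared `MultiscaleLinearAttention` module imported at the top of the diff, and swap the local `get_norm_layer` helper for `get_normalization`. For reference, the following is a minimal sketch of the ReLU linear-attention kernel that the deleted `LiteMLA.linear_attention` method computes; the standalone function name and signature here are illustrative only, not part of the library API.

import torch
import torch.nn.functional as F


def relu_linear_attention(qkv: torch.Tensor, head_dim: int, eps: float = 1e-15) -> torch.Tensor:
    # qkv: (batch, heads * 3 * head_dim, height, width), as produced by the fused 1x1 to_qkv conv
    batch_size, _, height, width = qkv.shape
    qkv = qkv.float().reshape(batch_size, -1, 3 * head_dim, height * width)

    query = qkv[:, :, :head_dim]                # (B, heads, d, N)
    key = qkv[:, :, head_dim : 2 * head_dim]    # (B, heads, d, N)
    value = qkv[:, :, 2 * head_dim :]           # (B, heads, d, N)

    # ReLU feature map instead of softmax; an all-ones row appended to `value`
    # carries the normalizer, so one pair of matmuls yields numerator and denominator.
    query, key = F.relu(query), F.relu(key)
    value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)    # (B, heads, d + 1, N)

    scores = value @ key.transpose(-1, -2)      # (B, heads, d + 1, d)
    output = scores @ query                     # (B, heads, d + 1, N)

    output = output[:, :, :-1] / (output[:, :, -1:] + eps)          # normalize by the ones row
    return output.reshape(batch_size, -1, height, width)

Because `value @ key.transpose(-1, -2)` contracts over the sequence dimension before the result ever touches `query`, the cost scales as O(N * d^2) rather than O(N^2 * d); the removed forward pass only takes the explicit `quadratic_attention` path when `height * width` is not larger than the head dimension.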