@@ -136,9 +136,9 @@ def __call__(
         if use_mask:
             # Special case requires building a mask. `mask_cache` is only needed
             # then.
-            assert (
-                self.mask_cache is not None
-            ), "mask_cache must be given if sliding window attention is used, or if input_pos given and T > 1"
+            assert self.mask_cache is not None, (
+                "mask_cache must be given if sliding window attention is used, or if input_pos given and T > 1"
+            )
             if is_causal:
                 mask = self.mask_cache[:T, :T].view(1, 1, T, T)
                 is_causal = False
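
For context, a minimal sketch of a boolean causal mask cache that the slicing above (`mask_cache[:T, :T].view(1, 1, T, T)`) would work against. The actual construction of `mask_cache` in this codebase may differ; `build_mask_cache` here is a hypothetical helper for illustration only.

import torch

def build_mask_cache(max_seq_length: int, device=None) -> torch.Tensor:
    # Hypothetical helper: lower-triangular boolean matrix, so position i may
    # attend to positions j <= i. The real cache may be built differently.
    ones = torch.ones((max_seq_length, max_seq_length), dtype=torch.bool, device=device)
    return torch.tril(ones)

mask_cache = build_mask_cache(8)
T = 4
# Same slicing as in the hunk above: a (1, 1, T, T) mask broadcastable over (B, n_head).
mask = mask_cache[:T, :T].view(1, 1, T, T)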
@@ -156,9 +156,7 @@ def __call__(
             nh_k = self.config.n_query_groups
             q_per_kv = nh_q // nh_k
             if q_per_kv > 1:
-                mask = mask.unsqueeze(2).expand(
-                    -1, -1, q_per_kv, -1, -1
-                ).reshape(B, nh_q, T, -1)
+                mask = mask.unsqueeze(2).expand(-1, -1, q_per_kv, -1, -1).reshape(B, nh_q, T, -1)
 
         # Efficient attention using Flash Attention CUDA kernels.
         # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
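
The collapsed one-liner broadcasts a per-KV-group mask to every query head under grouped-query attention. A standalone sketch with made-up shapes (the names mirror the diff, the concrete values are arbitrary):

import torch

B, T = 2, 5
n_query_groups, q_per_kv = 4, 3   # KV heads, and query heads per KV head
nh_q = n_query_groups * q_per_kv  # total number of query heads

# Mask laid out per KV group: (B, n_query_groups, T, T)
mask = torch.rand(B, n_query_groups, T, T) > 0.5

# Repeat each group's mask for its q_per_kv query heads: (B, nh_q, T, T)
expanded = mask.unsqueeze(2).expand(-1, -1, q_per_kv, -1, -1).reshape(B, nh_q, T, -1)
assert expanded.shape == (B, nh_q, T, T)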
@@ -236,7 +234,8 @@ def scaled_dot_product_attention(
             # `torch.nn.functional.scaled_dot_product_attention`
             sqrt_scale = math.sqrt(scale)
             scores = _attention_compute_scores(
-                sqrt_scale * query, sqrt_scale * key,
+                sqrt_scale * query,
+                sqrt_scale * key,
             )
             scores = do_softcapping(scores, attention_logit_softcapping)
             if mask is None and is_causal:
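
Folding sqrt(scale) into both query and key gives the same scores as scaling the raw dot products by scale, matching the scaling convention of `torch.nn.functional.scaled_dot_product_attention`. A small numerical check, assuming `_attention_compute_scores` amounts to a plain query @ key^T (an assumption about that helper, not its actual implementation):

import math
import torch

B, nh, T, hs = 1, 2, 4, 8
query = torch.randn(B, nh, T, hs)
key = torch.randn(B, nh, T, hs)
scale = 1.0 / math.sqrt(hs)

sqrt_scale = math.sqrt(scale)
# Scale folded into the operands, as in the hunk above.
scores_split = (sqrt_scale * query) @ (sqrt_scale * key).transpose(-2, -1)
# Scale applied to the raw dot products instead.
scores_post = scale * (query @ key.transpose(-2, -1))
assert torch.allclose(scores_split, scores_post, atol=1e-6)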