Commit ed81369

[GPT-OSS] Update the sliding_attention layer to use flashmask (#2606)
1 parent 7ca795a commit ed81369

9 files changed: 203 additions, 196 deletions


paddleformers/generation/utils.py

Lines changed: 61 additions & 8 deletions
@@ -63,6 +63,35 @@
 ]


+def _make_sliding_window_mask(input_shape, past_key_values_length=0, window_size=5):
+    """
+    Generate a sliding window mask that restricts each position to attend only to historical positions within the window.
+    Format: [bsz, 1, tgt_seq_len, src_seq_len], where True indicates allowed attention and False indicates masking.
+    """
+    batch_size, seq_length = input_shape
+    # Total sequence length = past sequence length + current sequence length (to build the complete mask)
+    total_length = past_key_values_length + seq_length
+
+    # Initialize the mask with all False values
+    mask = paddle.zeros((seq_length, total_length), dtype=paddle.bool)
+
+    for i in range(seq_length):
+        # Absolute position of the current token in the total sequence (including the past)
+        current_pos = past_key_values_length + i
+        # Window start position: max(0, current position - window size + 1)
+        start = max(0, current_pos - window_size + 1)
+        # Window end position: the current position (causal restriction, cannot attend beyond itself)
+        end = current_pos + 1  # the slice is half-open, so +1
+        # Mark the window range as True (attention allowed)
+        mask[i, start:end] = True
+
+    # Expand dimensions to [bsz, 1, tgt_seq_len, src_seq_len]
+    mask = mask.unsqueeze(0).unsqueeze(0)
+    # Tile across the batch dimension
+    mask = paddle.tile(mask, repeat_times=[batch_size, 1, 1, 1])
+    return mask
+
+
 def get_unfinished_flag(
     input_ids: Tensor, unfinished_flag: Tensor, eos_token_id: Union[int, list[int], list[list[int]]]
 ) -> Tensor:
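
For intuition, here is a small self-contained sketch that reimplements the same windowing logic and prints the pattern for a toy case (window_size=2, empty cache). It is illustrative only; the helper above in paddleformers/generation/utils.py is the authoritative version.

    # Toy check of the sliding-window pattern: each query row may attend to
    # itself and at most (window_size - 1) earlier positions.
    import paddle

    def make_window_mask(input_shape, past_key_values_length=0, window_size=5):
        batch_size, seq_length = input_shape
        total_length = past_key_values_length + seq_length
        mask = paddle.zeros((seq_length, total_length), dtype=paddle.bool)
        for i in range(seq_length):
            pos = past_key_values_length + i
            mask[i, max(0, pos - window_size + 1) : pos + 1] = True
        return paddle.tile(mask.unsqueeze(0).unsqueeze(0), repeat_times=[batch_size, 1, 1, 1])

    print(make_window_mask((1, 4), window_size=2)[0, 0].numpy())
    # [[ True False False False]
    #  [ True  True False False]
    #  [False  True  True False]
    #  [False False  True  True]]
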
@@ -354,29 +383,53 @@ def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id)
         return attention_mask

     @staticmethod
-    def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype):
+    def _prepare_decoder_attention_mask(
+        attention_mask, input_shape, past_key_values_length, dtype, sliding_window_size=None
+    ):
+        # Step 1: Process the input mask into the basic expanded mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
             if len(attention_mask.shape) == 2:
                 expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
-                # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
+                # Outside single-step decoding, combine the causal mask with the sliding window mask
                 if input_shape[-1] > 1:
-                    combined_attention_mask = _make_causal_mask(
-                        input_shape, past_key_values_length=past_key_values_length
-                    )
+                    # Basic causal mask (prevents attending to future positions)
+                    causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+                    # Sliding window mask (limits the historical attention range)
+                    if sliding_window_size is not None and sliding_window_size > 0:
+                        window_mask = _make_sliding_window_mask(
+                            input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size
+                        )
+                        # Intersect the sliding window mask with the causal mask (both restrictions must hold)
+                        combined_attention_mask = causal_mask & window_mask
+                    else:
+                        combined_attention_mask = (
+                            causal_mask  # Use the causal mask directly when the sliding window is disabled
+                        )
+
+                    # Combine with the user-provided mask (e.g., padding mask)
                     if get_env_device() in ["npu", "mlu", "intel_hpu"]:
                         expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
                     else:
                         expanded_attn_mask = expanded_attn_mask & combined_attention_mask
             # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
             elif len(attention_mask.shape) == 3:
                 expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
-            # if attention_mask is already 4-D, do nothing
+            # A 4-D mask is used directly
             else:
                 expanded_attn_mask = attention_mask
         else:
-            expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
-        # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
+            # With no input mask, generate the causal mask plus the sliding window mask (if enabled)
+            causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+            if sliding_window_size is not None and sliding_window_size > 0:
+                window_mask = _make_sliding_window_mask(
+                    input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size
+                )
+                expanded_attn_mask = causal_mask & window_mask
+            else:
+                expanded_attn_mask = causal_mask  # Use the causal mask directly when the sliding window is disabled
+
+        # Step 2: Convert the boolean mask to a numerical mask (adapted per device)
         if get_env_device() in ["npu", "mlu", "intel_hpu"]:
             x = paddle.to_tensor(0.0, dtype="float32")
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
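
Step 2 turns the combined boolean mask into an additive float mask, so masked positions contribute a large negative bias to the attention scores. A rough, hedged sketch of that conversion (the real method branches per device and uses the model dtype; paddle.where here is just one way to express the select):

    # Hedged sketch: True (keep) becomes 0, False (masked) becomes the dtype's
    # most negative value, which is later added to the attention scores.
    import paddle

    bool_mask = paddle.to_tensor([[True, False], [True, True]])
    neg_inf = paddle.finfo(paddle.float32).min
    additive_mask = paddle.where(
        bool_mask,
        paddle.zeros(bool_mask.shape, dtype="float32"),
        paddle.full(bool_mask.shape, neg_inf, dtype="float32"),
    )
    print(additive_mask.numpy())  # ~[[0.0, -3.4e38], [0.0, 0.0]]
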

paddleformers/nn/attention/eager_attention.py

Lines changed: 11 additions & 2 deletions
@@ -27,6 +27,7 @@ def eager_attention_forward(
     value: paddle.Tensor,
     attention_mask: Optional[paddle.Tensor] = None,
     dropout: float = 0.0,
+    sink: Optional[paddle.Tensor] = None,
     scaling: Optional[float] = None,
     is_causal: Optional[bool] = None,
     **kwargs,

@@ -45,8 +46,16 @@ def eager_attention_forward(
     if attention_mask is not None:
         causal_mask = attention_mask[:, :, :, : key.shape[-2]]
         attn_weights = attn_weights + causal_mask
-    attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype=paddle.float32).astype(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    if sink is not None:
+        sink = sink.reshape([1, -1, 1, 1]).expand([query.shape[0], -1, query.shape[-2], -1])
+        combined_logits = paddle.concat([attn_weights, sink], axis=-1)
+        probs = nn.functional.softmax(combined_logits, axis=-1, dtype=combined_logits.dtype)
+        scores = probs[..., :-1]  # we drop the sink here
+        attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+    else:
+        attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype=paddle.float32).astype(query.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = paddle.matmul(attn_weights, value)  # b h l l @ b h l d -> b h l d
     attn_output = attn_output.transpose([0, 2, 1, 3])  # b h l d -> b l h d
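
The sink branch above appends one learnable logit per head before the softmax and then discards that column, so each row of attention probabilities sums to less than one and the remainder is absorbed by the sink. A small numeric sketch of the effect (shapes and values are illustrative):

    # Toy demonstration of the sink-softmax: rows of `scores` sum to < 1
    # because the dropped sink column took part of the probability mass.
    import paddle
    import paddle.nn.functional as F

    attn_weights = paddle.randn([1, 2, 3, 3])  # [batch, heads, q_len, k_len]
    sink = paddle.zeros([2])                   # one logit per head
    sink = sink.reshape([1, -1, 1, 1]).expand([attn_weights.shape[0], -1, attn_weights.shape[-2], -1])

    combined_logits = paddle.concat([attn_weights, sink], axis=-1)
    probs = F.softmax(combined_logits, axis=-1)
    scores = probs[..., :-1]                   # drop the sink column, as above
    print(scores.sum(axis=-1))                 # every entry is strictly below 1.0
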

paddleformers/nn/attention/flashmask_attention.py

Lines changed: 3 additions & 2 deletions
@@ -36,7 +36,8 @@ def flashmask_attention_forward(
     # b,l,h,d
     if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.ndim == 3:
         attn_mask_startend_row_indices = attn_mask_startend_row_indices.unsqueeze(-1)
-
+    if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 1:
+        is_causal = True
     if sink is None:
         out = flashmask_attention(
             query,

@@ -54,7 +55,7 @@
             startend_row_indices=attn_mask_startend_row_indices,
             dropout_p=dropout,
             softmax_scale=scaling,
-            causal=is_causal,
+            causal=is_causal if is_causal is not None else False,
         )
         out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

paddleformers/nn/attention/sink_impl.py

Lines changed: 30 additions & 10 deletions
@@ -417,10 +417,10 @@ def backward(ctx, grad_output):
                 value_states,
                 raw_output,
                 lse_original,
-                dropout,
-                attention_mask,
-                causal,
-                scale,
+                dropout=dropout,
+                attention_mask=attention_mask,
+                causal=causal,
+                softmax_scale=scale,
             )
         else:
             grad_q_main, grad_k_repeated, grad_v_repeated = _flashmask_attention_backward_dispatch(

@@ -477,7 +477,16 @@ def backward(ctx, grad_output):
             )
             x = (g_ell.unsqueeze(-1) * query).to(query.dtype)
             _, grad_k_extra_repeated, _ = _flash_attention_backward_dispatch(
-                x, query, key_states, key_states, mu_k, lse_k, dropout, causal, scale
+                x,
+                query,
+                key_states,
+                key_states,
+                mu_k,
+                lse_k,
+                dropout=dropout,
+                attention_mask=attention_mask,
+                causal=causal,
+                softmax_scale=scale,
             )
         else:
             # Use FlashMask for computing mu_k

@@ -511,12 +520,23 @@ def backward(ctx, grad_output):
         # Combine main and extra gradients
         grad_q = grad_q_main + grad_q_extra
         grad_k = grad_k_main + grad_k_extra
-
-        # Return gradients (number of return values must match forward inputs)
-        if startend_row_indices is None:
-            return grad_q, grad_k, grad_v, grad_sink
+        if query.dtype != grad_q.dtype:
+            grad_q = grad_q.cast(query.dtype)
+        if key.dtype != grad_k.dtype:
+            grad_k = grad_k.cast(key.dtype)
+        if value.dtype != grad_v.dtype:
+            grad_v = grad_v.cast(value.dtype)
+        if sink.stop_gradient:
+            # Return gradients (number of return values must match forward inputs)
+            if startend_row_indices is None:
+                return grad_q, grad_k, grad_v, None  # grad_sink
+            else:
+                return grad_q, grad_k, grad_v, None, None
         else:
-            return grad_q, grad_k, grad_v, grad_sink, None
+            if startend_row_indices is None:
+                return grad_q, grad_k, grad_v, grad_sink
+            else:
+                return grad_q, grad_k, grad_v, grad_sink, None


 def sink_attention_forward(
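
The backward changes cast gradients back to the input dtypes and skip the sink gradient when the sink tensor is frozen (stop_gradient=True). Below is a hedged, self-contained sketch of the same return pattern on a toy paddle.autograd.PyLayer; the class and the math are illustrative, not the actual sink implementation:

    import paddle
    from paddle.autograd import PyLayer

    class ScaleBySink(PyLayer):
        @staticmethod
        def forward(ctx, x, sink):
            ctx.save_for_backward(x, sink)
            return x * sink

        @staticmethod
        def backward(ctx, grad_out):
            x, sink = ctx.saved_tensor()
            grad_x = grad_out * sink
            if sink.stop_gradient:
                # Mirror the commit: return None in place of the sink gradient
                # when the sink parameter is frozen.
                return grad_x, None
            return grad_x, grad_out * x

    x = paddle.randn([3])
    x.stop_gradient = False
    sink = paddle.ones([3])           # frozen: stop_gradient stays True by default
    out = ScaleBySink.apply(x, sink)
    out.sum().backward()              # only x receives a gradient
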

paddleformers/nn/pp_model.py

Lines changed: 2 additions & 0 deletions
@@ -551,6 +551,8 @@ def __init__(self, config: PretrainedConfig, **kwargs):
                 EmbeddingPipe,
                 shared_weight_attr="embedding_weight",
                 config=config,
+                embed_cls=self._embed_cls,
+                rotary_emb_cls=self._rotary_emb_cls,
             ),
             "model",
         )

paddleformers/transformers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@
     "ernie4_5_moe.modeling": ["Ernie4_5_MoeModel", "Ernie4_5_MoeForCausalLM", "Ernie4_5_MoeForCausalLMPipe"],
     "export": ["export_model"],
     "gpt_oss.configuration": ["GptOssConfig"],
-    "gpt_oss.modeling": ["GptOssModel", "GptOssForCausalLM"],
+    "gpt_oss.modeling": ["GptOssModel", "GptOssForCausalLM", "GptOssForCausalLMPipe"],
     "llama.configuration": [
         "LLAMA_PRETRAINED_INIT_CONFIGURATION",
         "LlamaConfig",

paddleformers/transformers/gpt_oss/__init__.py

Lines changed: 23 additions & 2 deletions
@@ -11,6 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
+from typing import TYPE_CHECKING

-from .configuration import *
-from .modeling import *
+from ...utils.lazy_import import _LazyModule
+
+import_structure = {
+    "configuration": ["GptOssConfig"],
+    "modeling": [
+        "GptOssModel",
+        "GptOssPretrainedModel",
+        "GptOssForCausalLM",
+        "GptOssForCausalLMPipe",
+    ],
+}
+if TYPE_CHECKING:
+    from .configuration import *
+    from .modeling import *
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )

paddleformers/transformers/gpt_oss/configuration.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,7 @@ def __init__(
         hidden_act: str = "silu",
         initializer_range: float = 0.02,
         max_position_embeddings=131072,
+        use_rmsnorm=True,
         rms_norm_eps: float = 1e-5,
         rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False},
         attention_dropout: float = 0.0,

@@ -50,6 +51,7 @@ def __init__(
         output_router_logits=False,
         use_cache=True,
         layer_types=None,
+        pp_seg_method="layer:GptOssDecoderLayer",
         **kwargs,
     ):
         self.vocab_size = vocab_size

@@ -67,6 +69,7 @@ def __init__(
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
+        self.use_rmsnorm = use_rmsnorm
         self.rms_norm_eps = rms_norm_eps
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling

@@ -91,6 +94,7 @@ def __init__(
         self.output_router_logits = output_router_logits
         self.use_cache = use_cache
         self.use_bias = False
+        self.pp_seg_method = pp_seg_method

         super().__init__(
             tie_word_embeddings=tie_word_embeddings,
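
A brief usage sketch of the two new configuration knobs; the other arguments keep their defaults, and the values shown simply restate the defaults from the diff above:

    from paddleformers.transformers import GptOssConfig

    # use_rmsnorm toggles RMSNorm; pp_seg_method controls how decoder layers are
    # segmented across pipeline-parallel stages.
    cfg = GptOssConfig(use_rmsnorm=True, pp_seg_method="layer:GptOssDecoderLayer")
    print(cfg.use_rmsnorm, cfg.pp_seg_method)
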
