
Commit 604c08a

Support general design for modeling (#2446)
1 parent 6dfc06d commit 604c08a


51 files changed: +3125 additions, −3702 deletions

paddleformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,7 @@
 modules = [
     "data",
     "datasets",
+    "nn",
     "mergekit",
     "ops",
     "peft",
@@ -69,6 +70,7 @@
     data,
     datasets,
     mergekit,
+    nn,
     ops,
     peft,
     quantization,

paddleformers/generation/utils.py

Lines changed: 63 additions & 12 deletions
@@ -29,6 +29,8 @@
 from ..transformers.model_outputs import ModelOutput
 from ..transformers.utils import get_scale_by_dtype
 from ..utils.log import logger
+from ..utils.masking_utils import _expand_2d_mask, _make_causal_mask
+from ..utils.tools import get_env_device
 from .configuration_utils import DEFAULT_MAX_NEW_TOKENS, GenerationConfig
 from .logits_process import (
     ForcedBOSTokenLogitsProcessor,
@@ -339,13 +341,61 @@ def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id)
         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
             (eos_token_id is not None) and (pad_token_id != eos_token_id)
         )
-        if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
-            attention_mask = (input_ids == pad_token_id).astype(paddle.get_default_dtype()) * get_scale_by_dtype(
-                return_positive=False
-            )
+        inputs_tensor = input_ids
+
+        # No information for attention mask inference -> return default attention mask
+        default_attention_mask = paddle.ones(input_ids.shape[:2], dtype=paddle.get_default_dtype())
+        if pad_token_id is None:
+            return default_attention_mask
+        can_infer_attention_mask = is_pad_token_in_inputs_ids * is_pad_token_not_equal_to_eos_token_id
+        attention_mask_from_padding = (inputs_tensor != pad_token_id).astype(paddle.get_default_dtype())
+
+        attention_mask = (
+            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+        )
+        return attention_mask
+
+    @staticmethod
+    def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype):
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            if len(attention_mask.shape) == 2:
+                expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
+                # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
+                if input_shape[-1] > 1:
+                    combined_attention_mask = _make_causal_mask(
+                        input_shape, past_key_values_length=past_key_values_length
+                    )
+                    if get_env_device() in ["npu", "mlu", "intel_hpu"]:
+                        expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
+                    else:
+                        expanded_attn_mask = expanded_attn_mask & combined_attention_mask
+            # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
+            elif len(attention_mask.shape) == 3:
+                expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
+            # if attention_mask is already 4-D, do nothing
+            else:
+                expanded_attn_mask = attention_mask
+        else:
+            expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+        # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
+        if get_env_device() in ["npu", "mlu", "intel_hpu"]:
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
+        elif get_env_device() == "gcu":
+            min_val = paddle.finfo(dtype).min
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(min_val, dtype=dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
         else:
-            attention_mask = paddle.zeros_like(input_ids, dtype=paddle.get_default_dtype())
-        return paddle.unsqueeze(attention_mask, axis=[1, 2])
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+        return expanded_attn_mask
 
     @staticmethod
     def prepare_seq_len_for_generation(input_ids, pad_token_id, eos_token_id):
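The rewritten prepare_attention_mask_for_generation now returns a flat [batch_size, seq_len] mask: pad positions are zeroed only when a pad token is actually present and distinct from eos_token_id, otherwise an all-ones default is used. A minimal sketch of that combination arithmetic (not part of the commit; the token ids are made up for illustration):

import paddle

pad_token_id = 0
input_ids = paddle.to_tensor([[11, 12, 13, 0],
                              [21, 22, 0, 0]])

default_attention_mask = paddle.ones(input_ids.shape[:2], dtype=paddle.get_default_dtype())
attention_mask_from_padding = (input_ids != pad_token_id).astype(paddle.get_default_dtype())

can_infer_attention_mask = True  # pad token appears in the input and differs from eos_token_id
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask
    + default_attention_mask * (not can_infer_attention_mask)
)
# attention_mask -> [[1., 1., 1., 0.],
#                    [1., 1., 0., 0.]]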
@@ -853,12 +903,8 @@ def generate(
                 bos_token_id, encoder_output=model_kwargs["inputs_embeds"]
             )
 
-        if model_kwargs.get("attention_mask", None) is None:
-            # TODO
-            # Init `attention_mask` depending on `pad_token_id`
-            model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
-                input_ids, pad_token_id, eos_token_id
-            )
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         self.is_encoder_decoder = self.config.is_encoder_decoder
 
         if self.is_encoder_decoder:
@@ -880,6 +926,11 @@ def generate(
 
         pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
 
+        if not kwargs_has_attention_mask and accepts_attention_mask:
+            model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
+                input_ids, pad_token_id, eos_token_id
+            )
+
         if generation_config.max_length != 0 and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS:
             logger.warning("`max_length` will be deprecated in future releases, use `max_new_tokens` instead.")
             generation_config.max_new_tokens = generation_config.max_length
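With this change, generate() only builds an attention mask when the caller did not pass one and the model's forward signature actually accepts it. A small sketch of that signature check (DummyModel is a hypothetical stand-in, not from the repo):

import inspect

class DummyModel:
    def forward(self, input_ids, attention_mask=None, **kwargs):
        ...

accepts_attention_mask = "attention_mask" in set(inspect.signature(DummyModel.forward).parameters.keys())
print(accepts_attention_mask)  # True -> the mask is prepared; False -> generation skips it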

paddleformers/nn/__init__.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from contextlib import suppress
+from typing import TYPE_CHECKING
+
+from ..utils.lazy_import import _LazyModule
+
+import_structure = {
+    "attention": ["AttentionInterface", "ALL_ATTENTION_FUNCTIONS"],
+    "criterion": ["LossInterface", "ALL_LOSS_FUNCTIONS", "CriterionLayer"],
+    "attention.eager_attention": ["eager_attention_forward"],
+    "attention.flashmask_attention": ["flashmask_attention_forward"],
+    "attention.interface": ["AttentionInterface", "ALL_ATTENTION_FUNCTIONS"],
+    "attention.sdpa_attention": ["sdpa_attention_forward"],
+    "attention.utils": ["repeat_kv"],
+    "criterion.dpo_loss": ["dpo_preprocess_inputs", "dpo_logps", "cal_dpo_loss", "dpo_loss_forward"],
+    "criterion.interface": ["LossInterface", "ALL_LOSS_FUNCTIONS", "CriterionLayer"],
+    "criterion.kto_loss": ["kto_preprocess_inputs", "_nested_gather", "kto_logps", "kto_loss", "kto_loss_forward"],
+    "criterion.loss_utils": ["calc_lm_head_logits", "subbatch"],
+    "criterion.sft_loss": [
+        "sft_preprocess_inputs",
+        "sft_postprocess_loss",
+        "sft_loss_forward",
+    ],
+    "activation": ["ACT2FN", "ClassInstantier", "ACT2CLS"],
+    "embedding": ["Embedding"],
+    "general": ["GeneralInterface"],
+    "linear": ["Linear"],
+    "lm_head": ["LMHead"],
+    "mlp": ["MLP"],
+    "norm": ["Norm", "LayerNorm", "RMSNorm"],
+}
+
+if TYPE_CHECKING:
+    from .activation import *
+    from .attention import *
+    from .criterion import *
+    from .embedding import *
+    from .general import *
+    from .linear import *
+    from .lm_head import *
+    from .mlp import *
+    from .norm import *
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )
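Because the package registers itself as a _LazyModule, nothing is imported eagerly at `import paddleformers.nn` time; each attribute access resolves the submodule listed in import_structure. A usage sketch, assuming an environment with this commit installed (names taken from import_structure above):

from paddleformers.nn import ACT2FN, LMHead, RMSNorm  # each import loads only its defining submodule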

paddleformers/nn/activation.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+
+import paddle.nn as nn
+
+
+class ClassInstantier(OrderedDict):
+    def __getitem__(self, key):
+        content = super().__getitem__(key)
+        cls, kwargs = content if isinstance(content, tuple) else (content, {})
+        return cls(**kwargs)
+
+
+ACT2CLS = {
+    "relu": nn.ReLU,
+    "relu6": nn.ReLU6,
+    "sigmoid": nn.Sigmoid,
+    "silu": nn.Silu,
+    "tanh": nn.Tanh,
+    "prelu": nn.PReLU,
+}
+
+ACT2FN = ClassInstantier(ACT2CLS)
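ClassInstantier instantiates the mapped class on every lookup, so ACT2FN["name"] hands back a ready-to-call layer rather than a class. A quick sketch (assuming the module is importable as added in this commit):

import paddle
from paddleformers.nn.activation import ACT2FN

act = ACT2FN["silu"]             # equivalent to paddle.nn.Silu(); a fresh instance per lookup
out = act(paddle.randn([2, 8]))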
paddleformers/nn/attention/__init__.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from contextlib import suppress
+from typing import TYPE_CHECKING
+
+from ...utils.lazy_import import _LazyModule
+
+import_structure = {
+    "eager_attention": ["eager_attention_forward"],
+    "flashmask_attention": ["flashmask_attention_forward"],
+    "interface": ["AttentionInterface", "ALL_ATTENTION_FUNCTIONS"],
+    "sdpa_attention": ["sdpa_attention_forward"],
+    "utils": ["repeat_kv"],
+}
+
+if TYPE_CHECKING:
+    from .eager_attention import *
+    from .flashmask_attention import *
+    from .interface import *
+    from .sdpa_attention import *
+    from .utils import *
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )
paddleformers/nn/attention/eager_attention.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import paddle
+import paddle.nn as nn
+
+from .utils import repeat_kv
+
+
+def eager_attention_forward(
+    module: nn.Layer,
+    query: paddle.Tensor,
+    key: paddle.Tensor,
+    value: paddle.Tensor,
+    attention_mask: Optional[paddle.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs,
+):
+    num_key_value_heads = None
+    if hasattr(module, "num_key_value_heads"):
+        num_key_value_heads = module.num_key_value_heads
+    elif hasattr(module, "num_key_value_groups"):
+        num_key_value_heads = module.num_key_value_groups
+
+    if num_key_value_heads is not None:
+        key = repeat_kv(key, module.num_key_value_heads)
+        value = repeat_kv(value, module.num_key_value_heads)
+
+    perm = [0, 2, 1, 3]  # b l h d -> b h l d
+    query = paddle.transpose(x=query, perm=perm)
+    key = paddle.transpose(x=key, perm=perm)
+    value = paddle.transpose(x=value, perm=perm)
+    attn_weights = paddle.matmul(query, key.transpose([0, 1, 3, 2])) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+    attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype=paddle.float32).astype(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    attn_output = paddle.matmul(attn_weights, value)  # b h l l @ b h l d -> b h l d
+    attn_output = attn_output.transpose([0, 2, 1, 3])  # b h l d -> b l h d
+    attn_output = paddle.reshape(x=attn_output, shape=[0, 0, attn_output.shape[2] * attn_output.shape[3]])
+
+    return attn_output, attn_weights
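For reference, the eager path can be exercised directly with [batch, seq_len, num_heads, head_dim] tensors. A minimal sketch, not from the repo: TinyAttention is a stand-in module that deliberately defines neither num_key_value_heads nor num_key_value_groups, so the repeat_kv branch is skipped.

import paddle
import paddle.nn as nn
from paddleformers.nn.attention import eager_attention_forward

class TinyAttention(nn.Layer):
    """Bare module: only supplies the .training flag used by the dropout call."""

b, l, h, d = 2, 16, 4, 64
q = paddle.randn([b, l, h, d])
k = paddle.randn([b, l, h, d])
v = paddle.randn([b, l, h, d])

out, weights = eager_attention_forward(TinyAttention(), q, k, v, scaling=d**-0.5)
print(out.shape)      # [2, 16, 256]   -> [batch, seq_len, num_heads * head_dim]
print(weights.shape)  # [2, 4, 16, 16] -> [batch, num_heads, seq_len, seq_len]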
paddleformers/nn/attention/flashmask_attention.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import paddle
+import paddle.nn as nn
+from paddle.nn.functional.flash_attention import flashmask_attention
+
+
+def flashmask_attention_forward(
+    module: nn.Layer,
+    query: paddle.Tensor,
+    key: paddle.Tensor,
+    value: paddle.Tensor,
+    attention_mask: Optional[paddle.Tensor] = None,
+    attn_mask_start_row_indices=None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs
+):
+    if attn_mask_start_row_indices is not None:
+        attn_mask_start_row_indices = attn_mask_start_row_indices.unsqueeze(-1)
+
+    # b,l,h,d
+    out = flashmask_attention(
+        query,
+        key,
+        value,
+        startend_row_indices=attn_mask_start_row_indices,
+        causal=True,
+    )
+    out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+    return out, None
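flashmask_attention_forward takes the same [batch, seq_len, num_heads, head_dim] layout. A heavily hedged sketch: it assumes a GPU build of Paddle with FlashAttention/FlashMask kernels, and that leaving attn_mask_start_row_indices as None falls back to plain causal flash attention, which this commit does not show; the shapes below are illustrative only.

import paddle
from paddleformers.nn.attention import flashmask_attention_forward

b, l, h, d = 2, 128, 8, 64
q = paddle.randn([b, l, h, d], dtype="bfloat16")
k = paddle.randn([b, l, h, d], dtype="bfloat16")
v = paddle.randn([b, l, h, d], dtype="bfloat16")

# `module` is not referenced in this code path, so None is passed for brevity.
out, _ = flashmask_attention_forward(None, q, k, v)
print(out.shape)  # [2, 128, 512] -> heads folded back into the hidden dimension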
