Commit f0f695a

API Update: Modify concat and equal usage (#2652)
Parent: dc9ac0e

File tree

80 files changed: 297 additions, 303 deletions (only a subset of the changed files is shown below)

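The hunks shown below are a mechanical rename: every visible call site that used paddle.concat now calls paddle.cat, with all arguments unchanged. A minimal sketch of the two spellings, assuming paddle.cat is an alias of paddle.concat that accepts the same axis keyword (which is what the call sites below rely on); the tensors are illustrative:

import paddle

x = paddle.ones([2, 3])
y = paddle.zeros([2, 3])

# Old spelling, used throughout the repo before this commit.
a = paddle.concat([x, y], axis=0)  # shape [4, 3]

# New spelling adopted by this commit. Arguments are unchanged;
# in particular the axis keyword (not dim) is kept at every call site.
b = paddle.cat([x, y], axis=0)     # shape [4, 3]

print(a.shape, b.shape)  # [4, 3] [4, 3]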

examples/experiments/deepseek_v3_pretrain/load_hf_ckpt.py

Lines changed: 1 addition & 1 deletion

@@ -257,7 +257,7 @@ def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]:
 
 def prepare_tensor(tensor, dst_shape, *, force_transpose=False):
     if isinstance(tensor, list):
-        t = paddle.concat(
+        t = paddle.cat(
            [
                paddle.transpose(tensor[0], perm=[1, 0]).contiguous(),
                paddle.transpose(tensor[1], perm=[1, 0]).contiguous(),

examples/experiments/deepseek_v3_pretrain/modeling.py

Lines changed: 14 additions & 14 deletions

@@ -813,14 +813,14 @@ def forward(
         sin = sin[None, :, None, :]
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, self.fuse_rope)
 
-        query_states = paddle.concat([q_nope, q_pe], axis=-1)
-        key_states = paddle.concat([k_nope, k_pe], axis=-1)
+        query_states = paddle.cat([q_nope, q_pe], axis=-1)
+        key_states = paddle.cat([k_nope, k_pe], axis=-1)
 
         # [bs, seq_len, num_head, head_dim]
         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = paddle.concat([past_key_value[0], key_states], axis=1)
-            value_states = paddle.concat([past_key_value[1], value_states], axis=1)
+            key_states = paddle.cat([past_key_value[0], key_states], axis=1)
+            value_states = paddle.cat([past_key_value[1], value_states], axis=1)
         past_key_value = (key_states, value_states) if use_cache else None
 
         has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)

@@ -1141,7 +1141,7 @@ def forward(
         hidden_states = self.hnorm(hidden_states)
         nextn_hidden_state = self.enorm(nextn_hidden_state)
 
-        concat_h = paddle.concat([nextn_hidden_state, hidden_states], axis=-1)
+        concat_h = paddle.cat([nextn_hidden_state, hidden_states], axis=-1)
         hidden_states = FP8LinearFunction.apply(concat_h, self.eh_proj)
 
         layer_outputs = super(DeepseekV2MTPLayer, self).forward(

@@ -1686,7 +1686,7 @@ def forward(
                 hidden_states = GatherOp.apply(hidden_states)
                 hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]])
 
-            inputs_embeds_cur_depth = paddle.concat(
+            inputs_embeds_cur_depth = paddle.cat(
                 [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
             )
 

@@ -1848,7 +1848,7 @@ def _set_cos_sin_cache(self, seq_len):
             / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
         )
 
-        emb = paddle.concat((freqs, freqs), axis=-1)
+        emb = paddle.cat((freqs, freqs), axis=-1)
         self.cos_cached = emb.cos() * _mscale
         self.sin_cached = emb.sin() * _mscale
 

@@ -1919,7 +1919,7 @@ def _set_cos_sin_cache(self, seq_len):
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # [seq_len, axis]
-       emb = paddle.concat([freqs, freqs], axis=-1)
+       emb = paddle.cat([freqs, freqs], axis=-1)
        # [1, seqlen, 1, axis]
        self.cos_cached = emb.cos()[None, :, None, :]
        self.sin_cached = emb.sin()[None, :, None, :]

@@ -2137,8 +2137,8 @@ def qkv_pre_process_no_fuse(
     sin = sin[None, :, None, :]
     q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, False)
 
-    query_states = paddle.concat([q_nope, q_pe], axis=-1)
-    key_states = paddle.concat([k_nope, k_pe], axis=-1)
+    query_states = paddle.cat([q_nope, q_pe], axis=-1)
+    key_states = paddle.cat([k_nope, k_pe], axis=-1)
 
     return query_states, key_states, value_states
 

@@ -2149,7 +2149,7 @@ def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads):
     value_states = kv[..., qk_nope_head_dim:]
 
     k_pe = k_pe.expand([k_pe.shape[0], k_pe.shape[1], num_heads, k_pe.shape[3]])
-    key_states = paddle.concat([k_nope, k_pe], axis=-1)
+    key_states = paddle.cat([k_nope, k_pe], axis=-1)
 
     return key_states, value_states
 

@@ -2315,7 +2315,7 @@ def forward(
            [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim],
            dtype=value_states.dtype,
        )
-       value_states_pad = paddle.concat([value_states, value_padding], axis=-1)
+       value_states_pad = paddle.cat([value_states, value_padding], axis=-1)
 
        attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn(
            query_states,

@@ -2541,7 +2541,7 @@ def backward(ctx, dout):
            [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim],
            dtype=value_states.dtype,
        )
-       value_states_pad = paddle.concat([value_states, value_padding], axis=-1)
+       value_states_pad = paddle.cat([value_states, value_padding], axis=-1)
 
        with paddle.no_grad():
 

@@ -2655,7 +2655,7 @@ def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_sca
        compressed_kv, kv_ln_weight, kv_ln_invar, d_kv_ln_t, eps
    )
 
-   d_kv_init = paddle.concat([d_compressed_kv, d_k_pe], axis=-1)
+   d_kv_init = paddle.cat([d_compressed_kv, d_k_pe], axis=-1)
 
    if hasattr(q_up_weight, "main_grad"):
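Two of the hunks above rebuild the rotary-embedding (RoPE) cache, where the frequency table is duplicated along the last axis before taking cos and sin. A standalone sketch of that pattern, with dimensions invented for illustration and paddle.cat assumed available as above:

import math
import paddle

dim, seq_len, base = 8, 16, 10000.0
# inv_freq[k] = base ** (-2k / dim), written via exp/log for scalar safety
inv_freq = paddle.exp(paddle.arange(0, dim, 2, dtype="float32") * (-math.log(base) / dim))
t = paddle.arange(seq_len, dtype="float32")

freqs = paddle.einsum("i,j->ij", t, inv_freq)  # [seq_len, dim // 2]
emb = paddle.cat([freqs, freqs], axis=-1)      # duplicate -> [seq_len, dim]
cos_cached = emb.cos()[None, :, None, :]       # [1, seq_len, 1, dim]
sin_cached = emb.sin()[None, :, None, :]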

examples/experiments/deepseek_v3_pretrain/modeling_pp.py

Lines changed: 11 additions & 11 deletions

@@ -205,7 +205,7 @@ def forward_without_residual(self, inputs):
 
        if self.send_mtp_embed:
            assert not self.output_mtp_embed_first, "forward_without_residual doesn't support output_mtp_embed_first"
-           hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+           hidden_states = paddle.cat([hidden_states, inputs_embeds_mtp], axis=-1)
            self.mtp_embed_shape = (
                inputs_embeds_mtp.shape
            )  # Save the shape of mtp_embed, used for backward propagation

@@ -248,9 +248,9 @@ def forward(self, inputs):
 
        if self.send_mtp_embed:
            if self.output_mtp_embed_first:
-               hidden_states = paddle.concat([inputs_embeds_mtp, hidden_states], axis=-1)
+               hidden_states = paddle.cat([inputs_embeds_mtp, hidden_states], axis=-1)
            else:
-               hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+               hidden_states = paddle.cat([hidden_states, inputs_embeds_mtp], axis=-1)
            self.mtp_embed_shape = (
                inputs_embeds_mtp.shape
            )  # Save the shape of mtp_embed shape, used for backward propagation

@@ -1501,7 +1501,7 @@ def forward(self, args):
            embeds_res = [inputs_embeds]
            mtp_embeds = []
            for depth in range(self.config.num_nextn_predict_layers):
-               inputs_embeds_mtp = paddle.concat(
+               inputs_embeds_mtp = paddle.cat(
                    [
                        inputs_embeds_ori[:, (depth + 1) :, :],
                        inputs_embeds_extra[:, : (depth + 1), :],

@@ -1519,7 +1519,7 @@ def forward(self, args):
                # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
                # else:
                # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
-               inputs_embeds = paddle.concat(embeds_res, axis=-1)
+               inputs_embeds = paddle.cat(embeds_res, axis=-1)
            else:
                global global_inputs_embeds_mtp_queue
                cloned_mtp_embeds = [t.detach() for t in mtp_embeds]

@@ -1586,7 +1586,7 @@ def forward(self, args):
        )
 
        if self.config.send_mtp_embed:
-           hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+           hidden_states = paddle.cat([hidden_states, inputs_embeds_mtp], axis=-1)
 
        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
 

@@ -1727,7 +1727,7 @@ def post_process_compute(self, inputs):
            l_aux,
        )
        if send_mtp_embed:
-           hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+           hidden_states = paddle.cat([hidden_states, inputs_embeds_mtp], axis=-1)
 
        return return_args(hidden_states)
 

@@ -1752,7 +1752,7 @@ def post_process_compute_for_fusion(self, inputs):
            hidden_states = hidden_states[0]
 
        if send_mtp_embed:
-           hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+           hidden_states = paddle.cat([hidden_states, inputs_embeds_mtp], axis=-1)
 
        return return_args(hidden_states)
 

@@ -1784,7 +1784,7 @@ def mlp_compute_dense(self, inputs):
        hidden_states = residual + hidden_states
 
        if self.config.send_mtp_embed:
-           hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+           hidden_states = paddle.cat([hidden_states, inputs_embeds_mtp], axis=-1)
 
        return hidden_states
 

@@ -1915,7 +1915,7 @@ def forward(self, args):
            )
            output_list.append(hidden_states)
 
-       hidden_states = paddle.concat(output_list, axis=-1)
+       hidden_states = paddle.cat(output_list, axis=-1)
        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
 
    def attn_compute_for_fusion(self, args):

@@ -1941,7 +1941,7 @@ def attn_compute_for_fusion(self, args):
        hidden_states = self.hnorm(hidden_states)
        nextn_hidden_state = self.enorm(nextn_hidden_state)
 
-       concat_h = paddle.concat([nextn_hidden_state, hidden_states], axis=-1)
+       concat_h = paddle.cat([nextn_hidden_state, hidden_states], axis=-1)
        hidden_states = FP8LinearFunction.apply(concat_h, self.eh_proj)
 
        # attention compute

examples/experiments/deepseek_v3_pretrain/moe_gate.py

Lines changed: 1 addition & 1 deletion

@@ -269,7 +269,7 @@ def top2gating(
        mask2 = self._one_hot_to_int64(indices2_s, self.num_experts)  # [S, E]
 
        # Note: mask1 and mask2 can be combined to form a single mask.
-       # mask = paddle.concat([mask1, mask2], axis=0)
+       # mask = paddle.cat([mask1, mask2], axis=0)
        # locations = paddle.cumsum(mask, axis=0) - 1
        # locations1, locations2 = locations.split(2, axis=0)
        # Compute locations in capacity buffer.

examples/experiments/deepseek_v3_pretrain/moe_layer.py

Lines changed: 2 additions & 2 deletions

@@ -294,7 +294,7 @@ def forward_drop_token(
                expert_out = expert(tokens_for_this_expert)
                outputs.append(expert_out)
                start_idx = end_idx
-           outs = paddle.concat(outputs, axis=0) if len(outputs) > 0 else paddle.to_tensor(0, dtype=sorted_tokens.dtype)
+           outs = paddle.cat(outputs, axis=0) if len(outputs) > 0 else paddle.to_tensor(0, dtype=sorted_tokens.dtype)
            if self.expert_parallel_degree > 1:
                new_x = paddle.empty_like(outs)
                new_x[gatherd_idxs] = outs

@@ -349,7 +349,7 @@ def expert_forward(self, dispatched_input, tokens_per_expert):
            # assert chunk.shape[0] != 0, "Cannot dispatch empty input"
            outputs += [expert(chunk)]
 
-       return paddle.concat(outputs, axis=0)
+       return paddle.cat(outputs, axis=0)
 
    def forward(self, hidden_states: paddle.Tensor):
        _, _, d_model = hidden_states.shape
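The moe_layer.py hunks follow the usual grouped-expert pattern: each expert runs on its contiguous slice of the expert-sorted tokens, and the per-expert outputs are stitched back together with a single concatenation. A toy version, with expert count, token counts, and sizes invented for illustration:

import paddle

num_experts, hidden = 3, 4
sorted_tokens = paddle.randn([9, hidden])  # tokens already grouped by expert
tokens_per_expert = [2, 4, 3]
experts = [paddle.nn.Linear(hidden, hidden) for _ in range(num_experts)]

outputs, start = [], 0
for expert, n in zip(experts, tokens_per_expert):
    outputs.append(expert(sorted_tokens[start:start + n]))
    start += n

outs = paddle.cat(outputs, axis=0)  # back to [9, hidden]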

examples/experiments/deepseek_v3_pretrain/moe_utils.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def _holder_size(self):
 
 def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk):
     x = paddle.flatten(x)
-    prob_permuted_indices = paddle.concat(
+    prob_permuted_indices = paddle.cat(
        [
            paddle.tensor.search._restrict_nonzero(x == i, total_true_num)
            for i, total_true_num in enumerate(num_tokens_per_expert_list)

paddleformers/datasets/rlhf_datasets/protocol.py

Lines changed: 1 addition & 1 deletion

@@ -544,7 +544,7 @@ def concat(data: List["DataProto"]) -> "DataProto":
        for batch in data:
            batch_lst.append(batch.batch)
        if batch_lst[0] is not None:
-           new_batch = paddle.concat(batch_lst, axis=0)
+           new_batch = paddle.cat(batch_lst, axis=0)
        else:
            new_batch = None
 

paddleformers/generation/utils.py

Lines changed: 9 additions & 11 deletions

@@ -565,11 +565,11 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None:
            token_type_ids = model_kwargs["token_type_ids"]
-           model_kwargs["token_type_ids"] = paddle.concat([token_type_ids, token_type_ids[:, -1:]], axis=-1)
+           model_kwargs["token_type_ids"] = paddle.cat([token_type_ids, token_type_ids[:, -1:]], axis=-1)
        if not is_encoder_decoder and model_kwargs.get("attention_mask", None) is not None:
            # update attention mask
            attention_mask = model_kwargs["attention_mask"]
-           model_kwargs["attention_mask"] = paddle.concat(
+           model_kwargs["attention_mask"] = paddle.cat(
                [
                    attention_mask,
                    paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype),

@@ -579,7 +579,7 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
        # update role_ids
        if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
            role_ids = model_kwargs["role_ids"]
-           model_kwargs["role_ids"] = paddle.concat([role_ids, role_ids[:, -1:]], axis=-1)
+           model_kwargs["role_ids"] = paddle.cat([role_ids, role_ids[:, -1:]], axis=-1)
 
        return model_kwargs
 

@@ -1235,7 +1235,7 @@ def greedy_search(
            scores = self.update_scores_for_generation(scores, next_scores, cur_len - origin_len, unfinished_flag)
            cur_len += 1
 
-           input_ids = paddle.concat([input_ids, next_tokens], axis=1)
+           input_ids = paddle.cat([input_ids, next_tokens], axis=1)
            if streamer is not None:
                if self.config.tensor_parallel_rank == 0:
                    streamer.put(next_tokens.cpu())

@@ -1379,7 +1379,7 @@ def sample(
            scores = self.update_scores_for_generation(scores, next_scores, cur_len - origin_len, unfinished_flag)
 
            cur_len += 1
-           input_ids = paddle.concat([input_ids, next_tokens], axis=1)
+           input_ids = paddle.cat([input_ids, next_tokens], axis=1)
            if streamer is not None:
                if self.config.tensor_parallel_rank == 0:
                    streamer.put(next_tokens.cpu())

@@ -1550,7 +1550,7 @@ def _post_process_(
            if eos_token_id is not None:
                next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id))
 
-           input_ids = paddle.concat([input_ids, next_tokens], axis=1)
+           input_ids = paddle.cat([input_ids, next_tokens], axis=1)
 
            if eos_token_id is not None:
                unfinished_flag = get_unfinished_flag(input_ids, unfinished_flag, eos_token_id)

@@ -1729,9 +1729,7 @@ def beam_search(
            beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))
 
            cur_len += 1
-           input_ids = paddle.concat(
-               [paddle.index_select(input_ids, beam_idx), beam_next_tokens.unsqueeze(-1)], axis=-1
-           )
+           input_ids = paddle.cat([paddle.index_select(input_ids, beam_idx), beam_next_tokens.unsqueeze(-1)], axis=-1)
 
            if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
                if not synced_gpus:

@@ -1893,7 +1891,7 @@ def group_beam_search(
                beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))
 
                input_ids[batch_group_indices] = group_input_ids[beam_idx]
-               group_input_ids = paddle.concat(
+               group_input_ids = paddle.cat(
                    [paddle.index_select(group_input_ids, index=beam_idx), beam_next_tokens.unsqueeze(-1)], axis=-1
                )
                current_tokens[batch_group_indices] = beam_next_tokens

@@ -1902,7 +1900,7 @@ def group_beam_search(
                    num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
                )
 
-           input_ids = paddle.concat([input_ids, current_tokens.unsqueeze(-1)], axis=-1)
+           input_ids = paddle.cat([input_ids, current_tokens.unsqueeze(-1)], axis=-1)
 
            cur_len += 1
 

paddleformers/nn/attention/eager_attention.py

Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ def eager_attention_forward(
 
    if sink is not None:
        sink = sink.reshape([1, -1, 1, 1]).expand([query.shape[0], -1, query.shape[-2], -1])
-       combined_logits = paddle.concat([attn_weights, sink], axis=-1)
+       combined_logits = paddle.cat([attn_weights, sink], axis=-1)
        probs = nn.functional.softmax(combined_logits, axis=-1, dtype=combined_logits.dtype)
        scores = probs[..., :-1]  # we drop the sink here
        attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
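This hunk is the attention-sink pattern: a per-head learnable logit is appended as one extra key column so the softmax can shed probability mass onto it, and the column is dropped again after normalization. A toy sketch with made-up shapes (the real code broadcasts sink across batch and query positions the same way):

import paddle
import paddle.nn.functional as F

attn_weights = paddle.randn([1, 2, 4, 6])  # [batch, heads, q_len, kv_len]
sink = paddle.randn([2]).reshape([1, -1, 1, 1]).expand([1, -1, 4, -1])

combined_logits = paddle.cat([attn_weights, sink], axis=-1)  # [1, 2, 4, 7]
probs = F.softmax(combined_logits, axis=-1)
scores = probs[..., :-1]  # drop the sink column -> back to [1, 2, 4, 6]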

paddleformers/nn/criterion/dpo_loss.py

Lines changed: 1 addition & 1 deletion

@@ -241,7 +241,7 @@ def cal_dpo_loss(
            rejected_logratios = policy_rejected_logps - reference_rejected_logps
            # As described in the KTO report, the KL term for chosen (rejected) is
            # estimated using the rejected (chosen) half.
-           loss = paddle.concat(
+           loss = paddle.cat(
                (
                    1 - F.sigmoid(self.dpo_config.beta * (chosen_logratios - rejected_KL)),
                    1 - F.sigmoid(self.dpo_config.beta * (chosen_KL - rejected_logratios)),
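For reference, the objective assembled in this hunk matches the KTO-style pairing the in-code comment describes: with $\beta$ = dpo_config.beta, $\sigma$ the sigmoid, and $r = \log\pi_\theta(y\mid x) - \log\pi_{\mathrm{ref}}(y\mid x)$ the policy/reference log-ratio, the concatenated per-half losses are

$$\mathcal{L} = \big[\; 1 - \sigma\big(\beta\,(r_{\mathrm{chosen}} - \mathrm{KL}_{\mathrm{rejected}})\big),\;\; 1 - \sigma\big(\beta\,(\mathrm{KL}_{\mathrm{chosen}} - r_{\mathrm{rejected}})\big) \;\big],$$

so each half's KL baseline is estimated from the opposite half of the batch.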
