@Xreki (Collaborator) commented Dec 29, 2025

PR Category

Bug Fix

Description

Fixes a typo in the naive_call_method_expand_pass code. Because of this typo, _node_need_rewrite returned False, so the dimension arguments of many expand operators could not be rewritten correctly.

Take the sample samples/transformers-auto-model/opus-mt-en-gv as an example. The corresponding code in the original sample is:

        getitem = l_attention_mask_[
            (slice(None, None, None), None, None, slice(None, None, None))
        ]
        expand = getitem.expand(1, 1, 29, 29)

After dimension generalization with the develop code:

        getitem = L_attention_mask_[(slice(None, None, None), None, None, slice(None, None, None))]
        size_2 = getitem.size(3)
        size_3 = getitem.size(3)
        expand = getitem.expand(1, 1, size_2, size_3);  getitem = size_2 = size_3 = None

In fact, L_attention_mask_.size() = torch.Size([128, 64]) and getitem.size() = torch.Size([128, 1, 1, 64]), so the expand operator fails with a dimension mismatch.

With the fix in this PR, the code after dimension generalization is:

        getitem = L_attention_mask_[(slice(None, None, None), None, None, slice(None, None, None))]
        size_2 = getitem.size(0)
        size_3 = getitem.size(3)
        size_4 = getitem.size(3)
        expand = getitem.expand(size_2, 1, size_4, size_3);  getitem = size_2 = size_4 = size_3 = None
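To make the mismatch concrete, here is a minimal pure-Python sketch of the shape rule that `torch.Tensor.expand` applies: a dimension can only be expanded if its current size is 1, or if the requested size equals the current size (or is -1). The helper `expanded_shape` is hypothetical and written only for illustration; it is not part of this PR or of PyTorch. It is applied to the `getitem` shape quoted above, `(128, 1, 1, 64)`.

```python
def expanded_shape(shape, sizes):
    """Hypothetical helper: shape that expand(*sizes) would produce, or ValueError.

    Models only the shape check of torch.Tensor.expand; no tensors are created.
    """
    if len(sizes) < len(shape):
        raise ValueError("expand needs at least as many sizes as tensor dims")
    lead = len(sizes) - len(shape)  # number of new leading dimensions
    result = []
    for i, target in enumerate(sizes):
        if i < lead:
            if target == -1:
                raise ValueError("-1 is not allowed for a new leading dimension")
            result.append(target)
            continue
        current = shape[i - lead]
        if target == -1 or target == current:
            result.append(current)   # size is kept as-is
        elif current == 1:
            result.append(target)    # singleton dimension is broadcast
        else:
            raise ValueError(
                f"cannot expand dim {i} of size {current} to size {target}"
            )
    return tuple(result)


getitem_shape = (128, 1, 1, 64)

# develop: the first argument stays the literal 1, which conflicts with dim 0 = 128.
try:
    expanded_shape(getitem_shape, (1, 1, getitem_shape[3], getitem_shape[3]))
    develop_ok = True
except ValueError:
    develop_ok = False

# After the fix: the first argument is getitem.size(0), so every dimension is compatible.
fixed_shape = expanded_shape(
    getitem_shape, (getitem_shape[0], 1, getitem_shape[3], getitem_shape[3])
)
```

Under this rule the develop-generated call is rejected at dimension 0, while the call generated after the fix is accepted.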

@lixinqi merged commit 16ac11d into PaddlePaddle:develop Dec 29, 2025
3 checks passed
@Xreki deleted the fix_dim_expand_pass branch December 29, 2025 08:11