@@ -759,7 +759,7 @@ def forward(
759
759
760
760
bsz , seq_len , dim = x .shape
761
761
if self .use_conv2d :
762
- x = x .reshape (bsz , seq_len , 1 , dim ).transpose (1 , 3 )
762
+ x = x .reshape (bsz , - 1 , 1 , dim ).transpose (1 , 3 )
763
763
764
764
new_qs = [wq (x ) for wq in self .wqs ]
765
765
new_ks = [wk (x ) for wk in self .wks ]
@@ -768,9 +768,7 @@ def forward(
768
768
if self .use_conv2d :
769
769
770
770
def from_conv2ds (ts ):
771
- return [
772
- t .reshape (bsz , self .head_dim , seq_len ).transpose (1 , 2 ) for t in ts
773
- ]
771
+ return [t .reshape (bsz , self .head_dim , - 1 ).transpose (1 , 2 ) for t in ts ]
774
772
775
773
new_qs = from_conv2ds (new_qs )
776
774
new_ks = from_conv2ds (new_ks )
@@ -800,9 +798,11 @@ def from_conv2ds(ts):
800
798
801
799
if self .use_conv2d :
802
800
y = (
803
- self .wo (y .reshape (bsz , seq_len , 1 , - 1 ).transpose (1 , 3 ))
801
+ self .wo (
802
+ y .reshape (bsz , - 1 , 1 , self .n_heads * self .head_dim ).transpose (1 , 3 )
803
+ )
804
804
.transpose (1 , 3 )
805
- .reshape (bsz , seq_len , - 1 )
805
+ .reshape (bsz , - 1 , self . dim )
806
806
)
807
807
else :
808
808
y = self .wo (y )
0 commit comments