@@ -759,7 +759,7 @@ def forward(
759759
760760 bsz , seq_len , dim = x .shape
761761 if self .use_conv2d :
762- x = x .reshape (bsz , seq_len , 1 , dim ).transpose (1 , 3 )
762+ x = x .reshape (bsz , - 1 , 1 , dim ).transpose (1 , 3 )
763763
764764 new_qs = [wq (x ) for wq in self .wqs ]
765765 new_ks = [wk (x ) for wk in self .wks ]
@@ -768,9 +768,7 @@ def forward(
768768 if self .use_conv2d :
769769
770770 def from_conv2ds (ts ):
771- return [
772- t .reshape (bsz , self .head_dim , seq_len ).transpose (1 , 2 ) for t in ts
773- ]
771+ return [t .reshape (bsz , self .head_dim , - 1 ).transpose (1 , 2 ) for t in ts ]
774772
775773 new_qs = from_conv2ds (new_qs )
776774 new_ks = from_conv2ds (new_ks )
@@ -800,9 +798,11 @@ def from_conv2ds(ts):
800798
801799 if self .use_conv2d :
802800 y = (
803- self .wo (y .reshape (bsz , seq_len , 1 , - 1 ).transpose (1 , 3 ))
801+ self .wo (
802+ y .reshape (bsz , - 1 , 1 , self .n_heads * self .head_dim ).transpose (1 , 3 )
803+ )
804804 .transpose (1 , 3 )
805- .reshape (bsz , seq_len , - 1 )
805+ .reshape (bsz , - 1 , self . dim )
806806 )
807807 else :
808808 y = self .wo (y )
0 commit comments