
Commit ab455e0

Audio encodings now match conv2d weight dtype in Gemma3nAudioSSCPConvBlock (#39743)
audio encodings now match conv weight dtype in Gemma3nAudioSSCPConvBlock
1 parent 4b3a1a6 commit ab455e0

2 files changed: 6 additions & 2 deletions


src/transformers/models/gemma3n/modeling_gemma3n.py

Lines changed: 3 additions & 1 deletion
@@ -695,7 +695,9 @@ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
         # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
         # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
         # F.pad applies to last two dims: F_in then T_in
-        audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
+        audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0).to(
+            self.conv.weight.dtype
+        )
         # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
         # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
         audio_encodings_conv = self.conv(audio_encodings_padded)
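
Why the cast is needed (context, not part of the diff): F.pad returns a tensor in the dtype of its input, so when the conv weights have been loaded in reduced precision (e.g. torch.bfloat16) but the incoming audio features are float32, the following nn.Conv2d call fails with a dtype-mismatch error. Below is a minimal standalone sketch of the failure mode and the fix; the shapes, channel count, and padding are illustrative, not the real Gemma3n configuration.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative setup, not the actual Gemma3nAudioSSCPConvBlock:
# conv weights loaded in bfloat16, audio features arriving as float32.
conv = nn.Conv2d(1, 32, kernel_size=(3, 3), bias=False).to(torch.bfloat16)
audio_encodings = torch.randn(2, 1, 100, 80)  # [B, C_in, T_in, F_in], float32
manual_padding = (1, 1, 0, 2)  # (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)

padded = F.pad(audio_encodings, manual_padding, mode="constant", value=0.0)
# padded keeps the input dtype (float32); calling conv(padded) at this point
# would raise a dtype-mismatch RuntimeError.

# The fix from this commit: cast the padded activations to the weight dtype.
padded = padded.to(conv.weight.dtype)
out = conv(padded)  # requires a backend with bfloat16 conv support
print(out.shape, out.dtype)  # torch.Size([2, 32, 100, 80]) torch.bfloat16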

src/transformers/models/gemma3n/modular_gemma3n.py

Lines changed: 3 additions & 1 deletion
@@ -1264,7 +1264,9 @@ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
         # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
         # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
         # F.pad applies to last two dims: F_in then T_in
-        audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
+        audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0).to(
+            self.conv.weight.dtype
+        )
         # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
         # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
         audio_encodings_conv = self.conv(audio_encodings_padded)
