Skip to content

Commit 1580c79

Browse files
committed
patch owsm for torch 2.8.0
1 parent 16b0045 commit 1580c79

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

scripts/asr/owsm.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@
5353
s2t.dtype = "float32"
5454
s2t.device = DEVICE
5555

56+
# NOTE: torch==2.8.0 doesn't support mps without patching torch.nn.Linear's foward method to convert the input to contiguous, fixed in 2.9.0
57+
if torch.__version__ == "2.8.0":
58+
from torch.nn import Linear
59+
60+
def forward(self, input):
61+
return torch.nn.functional.linear(
62+
input.contiguous(), self.weight, self.bias
63+
)
64+
65+
Linear.forward = forward
66+
5667

5768
def _naive_decode_long(wav_array, config, chunk_size=30 * TARGET_SAMPLE_RATE):
5869
predictions = []

0 commit comments

Comments
 (0)