We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 16b0045 commit 1580c79Copy full SHA for 1580c79
scripts/asr/owsm.py
@@ -53,6 +53,17 @@
53
s2t.dtype = "float32"
54
s2t.device = DEVICE
55
56
+ # NOTE: torch==2.8.0 doesn't support mps without patching torch.nn.Linear's foward method to convert the input to contiguous, fixed in 2.9.0
57
+ if torch.__version__ == "2.8.0":
58
+ from torch.nn import Linear
59
+
60
+ def forward(self, input):
61
+ return torch.nn.functional.linear(
62
+ input.contiguous(), self.weight, self.bias
63
+ )
64
65
+ Linear.forward = forward
66
67
68
def _naive_decode_long(wav_array, config, chunk_size=30 * TARGET_SAMPLE_RATE):
69
predictions = []
0 commit comments