Skip to content

Commit c176603

Browse files
authored
Merge pull request #21 from audiohacking/copilot/fix-decoding-audio-error-again
Fix MPS tensor dtype in audio generation pipeline
2 parents b41ca42 + 66cef0f commit c176603

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

backend/app/services/music_service.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -757,6 +757,8 @@ def generate_with_callback(inputs, callback=None, **kwargs):
757 757   continuous_segments=continuous_segment,
758 758   starts=starts,
759 759   )
    760 + # Convert to long immediately after generation (tokens must be integers)
    761 + curr_token = curr_token.long()
760 762   frames.append(curr_token[0:1,])
761 763
762 764   def _pad_audio_token(token):
@@ -791,6 +793,8 @@ def _pad_audio_token(token):
791 793   continuous_segments=None,
792 794   starts=None,
793 795   )
    796 + # Convert to long immediately after generation (tokens must be integers)
    797 + curr_token = curr_token.long()
794 798   if torch.any(curr_token[0:1, :] >= pipeline.config.audio_eos_id):
795 799   break
796 800   frames.append(curr_token[0:1,])

0 commit comments

Comments
 (0)