Skip to content

Commit 9472ad4

Browse files
authored
Merge pull request #45 from zhonghuihong/main
websocket tts
2 parents 01b3572 + 3ca1c59 commit 9472ad4

File tree

4 files changed

+122
-21
lines changed

4 files changed

+122
-21
lines changed

config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ TTS:
9494
appid: "你的appid"
9595
token: 你的access_token
9696
cluster: 你的cluster
97+
GoSherpaTTS:
98+
type: gosherpa
99+
cluster: "ws://127.0.0.1:8848/tts"
100+
output_dir: "tmp/"
97101

98102
# LLM配置
99103
LLM:
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package gosherpa
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"github.com/gorilla/websocket"
7+
"os"
8+
"path/filepath"
9+
"time"
10+
"xiaozhi-server-go/src/core/providers/tts"
11+
)
12+
13+
// Provider Sherpa TTS提供者实现
14+
type Provider struct {
15+
*tts.BaseProvider
16+
conn *websocket.Conn
17+
}
18+
19+
// NewProvider 创建Sherpa TTS提供者
20+
func NewProvider(config *tts.Config, deleteFile bool) (*Provider, error) {
21+
base := tts.NewBaseProvider(config, deleteFile)
22+
23+
dialer := websocket.Dialer{
24+
HandshakeTimeout: 10 * time.Second, // 设置握手超时
25+
}
26+
conn, _, err := dialer.DialContext(context.Background(), config.Cluster, map[string][]string{})
27+
if err != nil {
28+
return nil, err
29+
}
30+
31+
return &Provider{
32+
BaseProvider: base,
33+
conn: conn,
34+
}, nil
35+
}
36+
37+
// ToTTS 将文本转换为音频文件,并返回文件路径
38+
func (p *Provider) ToTTS(text string) (string, error) {
39+
// 获取配置的声音,如果未配置则使用默认值
40+
SherpaTTSStartTime := time.Now()
41+
42+
// 创建临时文件路径用于保存 SherpaTTS 生成的 MP3
43+
outputDir := p.BaseProvider.Config().OutputDir
44+
if outputDir == "" {
45+
outputDir = os.TempDir() // Use system temp dir if not configured
46+
}
47+
// Ensure output directory exists
48+
if err := os.MkdirAll(outputDir, 0755); err != nil {
49+
return "", fmt.Errorf("创建输出目录失败 '%s': %v", outputDir, err)
50+
}
51+
// Use a unique filename
52+
tempFile := filepath.Join(outputDir, fmt.Sprintf("go_sherpa_tts_%d.wav", time.Now().UnixNano()))
53+
54+
p.conn.WriteMessage(websocket.TextMessage, []byte(text))
55+
_, bytes, err := p.conn.ReadMessage()
56+
57+
if err != nil {
58+
return "", fmt.Errorf("go-sherpa-tts 获取音频流失败: %v", err)
59+
}
60+
61+
ttsDuration := time.Since(SherpaTTSStartTime)
62+
fmt.Println(fmt.Sprintf("go-sherpa-tts 语音合成完成,耗时: %s", ttsDuration))
63+
64+
// 将音频数据写入临时文件
65+
err = os.WriteFile(tempFile, bytes, 0644)
66+
if err != nil {
67+
return "", fmt.Errorf("写入音频文件 '%s' 失败: %v", tempFile, err)
68+
}
69+
70+
// 检查文件是否成功创建
71+
if _, err := os.Stat(tempFile); os.IsNotExist(err) {
72+
return "", fmt.Errorf("go-sherpa-tts 未能创建音频文件: %s", tempFile)
73+
}
74+
75+
// Return the path to the generated audio file
76+
return tempFile, nil
77+
}
78+
79+
func init() {
80+
// 注册Sherpa TTS提供者
81+
tts.Register("gosherpa", func(config *tts.Config, deleteFile bool) (tts.Provider, error) {
82+
return NewProvider(config, deleteFile)
83+
})
84+
}

src/core/utils/audio.go

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"io"
66
"os"
77
"path/filepath"
8+
"strings"
89
"sync"
910

1011
"github.com/hajimehoshi/go-mp3"
@@ -384,34 +385,45 @@ func AudioToPCMData(audioFile string) ([][]byte, float64, error) {
384385

385386
// AudioToOpusData 将音频文件转换为Opus数据块
386387
func AudioToOpusData(audioFile string) ([][]byte, float64, error) {
387-
// 先将MP3转为PCM
388-
pcmData, duration, err := AudioToPCMData(audioFile)
389-
if err != nil {
390-
return nil, 0, fmt.Errorf("PCM转换失败: %v", err)
391-
}
392388

393-
if len(pcmData) == 0 {
394-
return nil, 0, fmt.Errorf("PCM转换结果为空")
395-
}
396-
397-
// 打开MP3文件获取采样率
398-
file, err := os.Open(audioFile)
399-
if err != nil {
400-
return nil, 0, fmt.Errorf("打开音频文件失败: %v", err)
401-
}
402-
defer file.Close()
403-
404-
// 检查MP3文件格式是否有效
405-
_, err = mp3.NewDecoder(file)
406-
if err != nil {
407-
return nil, 0, fmt.Errorf("创建MP3解码器失败: %v", err)
408-
}
389+
var pcmData [][]byte
390+
var err error
391+
var duration float64
409392

410393
// 获取采样率 (固定使用24000Hz作为Opus编码的采样率)
411394
// 如果采样率不是24000Hz,PCMSlicesToOpusData会处理重采样
412395
opusSampleRate := 24000
413396
channels := 1
414397

398+
if strings.HasSuffix(audioFile, ".mp3") {
399+
// 先将MP3转为PCM
400+
pcmData, duration, err = AudioToPCMData(audioFile)
401+
if err != nil {
402+
return nil, 0, fmt.Errorf("PCM转换失败: %v", err)
403+
}
404+
405+
if len(pcmData) == 0 {
406+
return nil, 0, fmt.Errorf("PCM转换结果为空")
407+
}
408+
409+
// 打开MP3文件获取采样率
410+
file, err := os.Open(audioFile)
411+
if err != nil {
412+
return nil, 0, fmt.Errorf("打开音频文件失败: %v", err)
413+
}
414+
defer file.Close()
415+
416+
// 检查MP3文件格式是否有效
417+
_, err = mp3.NewDecoder(file)
418+
if err != nil {
419+
return nil, 0, fmt.Errorf("创建MP3解码器失败: %v", err)
420+
}
421+
} else {
422+
var singlePcmData []byte
423+
singlePcmData, err = ReadPCMDataFromWavFile(audioFile)
424+
pcmData = [][]byte{singlePcmData}
425+
}
426+
415427
// 将PCM转换为Opus
416428
opusData, err := PCMSlicesToOpusData(pcmData, opusSampleRate, channels, 0)
417429
if err != nil {

src/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
_ "xiaozhi-server-go/src/core/providers/llm/openai"
2424
_ "xiaozhi-server-go/src/core/providers/tts/doubao"
2525
_ "xiaozhi-server-go/src/core/providers/tts/edge"
26+
_ "xiaozhi-server-go/src/core/providers/tts/gosherpa"
2627
_ "xiaozhi-server-go/src/core/providers/vlllm/ollama"
2728
_ "xiaozhi-server-go/src/core/providers/vlllm/openai"
2829

0 commit comments

Comments
 (0)