44 "bytes"
55 "context"
66 "crypto/tls"
7+ "errors"
78 "fmt"
89 "net"
910 "net/http"
@@ -14,6 +15,10 @@ import (
1415 "github.com/gorilla/websocket"
1516)
1617
18+ var (
19+ WSGrpcError = errors .New ("wsgrpc error" )
20+ )
21+
1722// ---------------------------------------
1823// 通用 websocketConn 实现 net.Conn 接口
1924// ---------------------------------------
@@ -35,16 +40,20 @@ func (c *websocketConn) Read(p []byte) (int, error) {
3540 if c .readBuffer .Len () == 0 {
3641 messageType , data , err := c .ws .ReadMessage ()
3742 if err != nil {
38- return 0 , err
43+ return 0 , errors . Join ( err , errors . New ( "wsgrpc read message error" ), WSGrpcError )
3944 }
4045 // 只接受二进制数据
4146 if messageType != websocket .BinaryMessage {
42- return 0 , fmt .Errorf ("unexpected message type: %d" , messageType )
47+ return 0 , errors . Join ( fmt .Errorf ("unexpected message type: %d" , messageType ), WSGrpcError )
4348 }
4449 c .readBuffer .Write (data )
4550 }
4651
47- return c .readBuffer .Read (p )
52+ if n , err := c .readBuffer .Read (p ); err != nil {
53+ return n , errors .Join (err , WSGrpcError )
54+ } else {
55+ return n , nil
56+ }
4857}
4958
5059// Write 将数据作为单条二进制消息发送
@@ -54,14 +63,18 @@ func (c *websocketConn) Write(p []byte) (int, error) {
5463
5564 err := c .ws .WriteMessage (websocket .BinaryMessage , p )
5665 if err != nil {
57- return 0 , err
66+ return 0 , errors . Join ( err , errors . New ( "wsgrpc write message error" ), WSGrpcError )
5867 }
5968 return len (p ), nil
6069}
6170
6271// Close 关闭 websocket 连接
6372func (c * websocketConn ) Close () error {
64- return c .ws .Close ()
73+ err := c .ws .Close ()
74+ if err != nil {
75+ return errors .Join (err , errors .New ("wsgrpc close error" ), WSGrpcError )
76+ }
77+ return nil
6578}
6679
6780// LocalAddr 返回本地地址,通过 websocket 底层连接获取
@@ -83,9 +96,12 @@ func (c *websocketConn) RemoteAddr() net.Addr {
8396// SetDeadline 同时设置读写超时
8497func (c * websocketConn ) SetDeadline (t time.Time ) error {
8598 if err := c .ws .SetReadDeadline (t ); err != nil {
86- return err
99+ return errors . Join ( err , errors . New ( "wsgrpc set read deadline error" ), WSGrpcError )
87100 }
88- return c .ws .SetWriteDeadline (t )
101+ if err := c .ws .SetWriteDeadline (t ); err != nil {
102+ return errors .Join (err , errors .New ("wsgrpc set write deadline error" ), WSGrpcError )
103+ }
104+ return nil
89105}
90106
91107// SetReadDeadline 设置读超时
@@ -101,18 +117,26 @@ func (c *websocketConn) SetWriteDeadline(t time.Time) error {
101117// ---------------------------------------
102118// 客户端 WebSocket Dialer
103119// ---------------------------------------
120+ type LogInterface interface {
121+ Infof (format string , args ... interface {})
122+ Errorf (format string , args ... interface {})
123+ Tracef (format string , args ... interface {})
124+ }
104125
105126// WebsocketDialer 返回一个可以用于 grpc.WithContextDialer 的拨号函数;该函数通过 websocket 建立连接。
106127// 参数 url 表示 websocket 服务器地址;header 可用于传递额外的 header 参数。
107- func WebsocketDialer (url string , header http.Header , insecure bool ) func (ctx context.Context , addr string ) (net.Conn , error ) {
128+ func WebsocketDialer (url string , header http.Header , insecure bool , log LogInterface ) func (ctx context.Context , addr string ) (net.Conn , error ) {
108129 return func (ctx context.Context , addr string ) (net.Conn , error ) {
109130 dialer := websocket.Dialer {
110131 TLSClientConfig : & tls.Config {InsecureSkipVerify : insecure },
111132 }
133+ log .Tracef ("dialing websocket server [%s]" , url )
112134 ws , _ , err := dialer .DialContext (ctx , url , header )
113135 if err != nil {
114- return nil , err
136+ log .Errorf ("wsgrpc dialer error: %v" , err )
137+ return nil , errors .Join (err , errors .New ("wsgrpc dialer error" ), WSGrpcError )
115138 }
139+ log .Tracef ("websocket connection connect done" )
116140 return & websocketConn {ws : ws }, nil
117141 }
118142}
@@ -160,11 +184,11 @@ func (l *WSListener) Accept() (net.Conn, error) {
160184 select {
161185 case conn , ok := <- l .connCh :
162186 if ! ok {
163- return nil , fmt .Errorf ("listener closed" )
187+ return nil , errors . Join ( fmt .Errorf ("listener closed" ), WSGrpcError )
164188 }
165189 return conn , nil
166190 case <- l .done :
167- return nil , fmt .Errorf ("listener closed" )
191+ return nil , errors . Join ( fmt .Errorf ("listener closed" ), WSGrpcError )
168192 }
169193}
170194
0 commit comments