Skip to content

Commit f8a1be7

Browse files
joeybloggsjoeybloggs
authored andcommitted
Merge branch 'update-to-gorilla-websocket'
2 parents 78ede96 + e92bff0 commit f8a1be7

File tree

4 files changed

+48
-24
lines changed

4 files changed

+48
-24
lines changed

context.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import (
99
"strings"
1010
"time"
1111

12+
"github.com/gorilla/websocket"
1213
"golang.org/x/net/context"
13-
"golang.org/x/net/websocket"
1414
)
1515

1616
// Param is a single URL parameter, consisting of a key and a value.

group.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
package lars
22

33
import (
4-
"net/http"
54
"strconv"
65
"strings"
76

8-
"golang.org/x/net/websocket"
7+
"github.com/gorilla/websocket"
98
)
109

1110
// IRouteGroup interface for router group
@@ -27,7 +26,7 @@ type IRoutes interface {
2726
Head(string, ...Handler)
2827
Connect(string, ...Handler)
2928
Trace(string, ...Handler)
30-
WebSocket(string, Handler)
29+
WebSocket(websocket.Upgrader, string, Handler)
3130
}
3231

3332
// routeGroup struct containing all fields and methods for use.
@@ -159,22 +158,21 @@ func (g *routeGroup) Match(methods []string, path string, h ...Handler) {
159158
}
160159

161160
// WebSocket adds a websocket route
162-
func (g *routeGroup) WebSocket(path string, h Handler) {
161+
func (g *routeGroup) WebSocket(upgrader websocket.Upgrader, path string, h Handler) {
163162

164163
handler := g.lars.wrapHandler(h)
165164
g.Get(path, func(c Context) {
166165

167166
ctx := c.BaseContext()
167+
var err error
168168

169-
wss := websocket.Server{
170-
Handler: func(ws *websocket.Conn) {
171-
ctx.websocket = ws
172-
ctx.response.status = http.StatusSwitchingProtocols
173-
ctx.Next()
174-
},
169+
ctx.websocket, err = upgrader.Upgrade(ctx.response, ctx.request, nil)
170+
if err != nil {
171+
return
175172
}
176173

177-
wss.ServeHTTP(ctx.response, ctx.request)
174+
defer ctx.websocket.Close()
175+
c.Next()
178176
}, handler)
179177
}
180178

group_test.go

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package lars
22

33
import (
4-
"bytes"
54
"fmt"
5+
"log"
6+
"net/http"
67
"net/http/httptest"
78
"testing"
89

9-
"golang.org/x/net/websocket"
10+
"github.com/gorilla/websocket"
1011
. "gopkg.in/go-playground/assert.v1"
1112
)
1213

@@ -22,14 +23,28 @@ import (
2223
//
2324

2425
func TestWebsockets(t *testing.T) {
26+
27+
origin := "http://localhost"
28+
29+
var upgrader = websocket.Upgrader{
30+
ReadBufferSize: 1024,
31+
WriteBufferSize: 1024,
32+
CheckOrigin: func(r *http.Request) bool {
33+
o := r.Header.Get(Origin)
34+
return o == origin
35+
},
36+
}
37+
2538
l := New()
26-
l.WebSocket("/ws", func(c Context) {
39+
l.WebSocket(upgrader, "/ws", func(c Context) {
2740

28-
recv := make([]byte, 1000)
41+
messageType, b, err := c.WebSocket().ReadMessage()
42+
if err != nil {
43+
return
44+
}
2945

30-
i, err := c.WebSocket().Read(recv)
3146
if err == nil {
32-
_, err := c.WebSocket().Write(recv[:i])
47+
err := c.WebSocket().WriteMessage(messageType, b)
3348
if err != nil {
3449
panic(err)
3550
}
@@ -40,19 +55,29 @@ func TestWebsockets(t *testing.T) {
4055
defer server.Close()
4156

4257
addr := server.Listener.Addr().String()
43-
origin := "http://localhost"
58+
59+
header := make(http.Header, 0)
60+
header.Set(Origin, origin)
4461

4562
url := fmt.Sprintf("ws://%s/ws", addr)
46-
ws, err := websocket.Dial(url, "", origin)
63+
ws, _, err := websocket.DefaultDialer.Dial(url, header)
64+
if err != nil {
65+
log.Fatal("dial:", err)
66+
}
4767
Equal(t, err, nil)
4868

4969
defer ws.Close()
5070

51-
_, err = ws.Write([]byte("websockets in action!"))
71+
err = ws.WriteMessage(websocket.TextMessage, []byte("websockets in action!"))
5272
Equal(t, err, nil)
5373

54-
buf := new(bytes.Buffer)
55-
_, err = buf.ReadFrom(ws)
74+
typ, b, err := ws.ReadMessage()
5675
Equal(t, err, nil)
57-
Equal(t, "websockets in action!", buf.String())
76+
Equal(t, typ, websocket.TextMessage)
77+
Equal(t, "websockets in action!", string(b))
78+
79+
wsBad, res, err := websocket.DefaultDialer.Dial(url, nil)
80+
NotEqual(t, err, nil)
81+
Equal(t, wsBad, nil)
82+
Equal(t, res.StatusCode, http.StatusForbidden)
5883
}

lars.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ const (
7373
XForwardedFor = "X-Forwarded-For"
7474
XRealIP = "X-Real-Ip"
7575
Allow = "Allow"
76+
Origin = "Origin"
7677

7778
Gzip = "gzip"
7879

0 commit comments

Comments
 (0)