11package lars
22
33import (
4- "bytes"
54 "fmt"
5+ "log"
6+ "net/http"
67 "net/http/httptest"
78 "testing"
89
9- "golang.org/x/net /websocket"
10+ "github.com/gorilla /websocket"
1011 . "gopkg.in/go-playground/assert.v1"
1112)
1213
@@ -22,14 +23,28 @@ import (
2223//
2324
2425func TestWebsockets (t * testing.T ) {
26+
27+ origin := "http://localhost"
28+
29+ var upgrader = websocket.Upgrader {
30+ ReadBufferSize : 1024 ,
31+ WriteBufferSize : 1024 ,
32+ CheckOrigin : func (r * http.Request ) bool {
33+ o := r .Header .Get (Origin )
34+ return o == origin
35+ },
36+ }
37+
2538 l := New ()
26- l .WebSocket ("/ws" , func (c Context ) {
39+ l .WebSocket (upgrader , "/ws" , func (c Context ) {
2740
28- recv := make ([]byte , 1000 )
41+ messageType , b , err := c .WebSocket ().ReadMessage ()
42+ if err != nil {
43+ return
44+ }
2945
30- i , err := c .WebSocket ().Read (recv )
3146 if err == nil {
32- _ , err := c .WebSocket ().Write ( recv [: i ] )
47+ err := c .WebSocket ().WriteMessage ( messageType , b )
3348 if err != nil {
3449 panic (err )
3550 }
@@ -40,19 +55,29 @@ func TestWebsockets(t *testing.T) {
4055 defer server .Close ()
4156
4257 addr := server .Listener .Addr ().String ()
43- origin := "http://localhost"
58+
59+ header := make (http.Header , 0 )
60+ header .Set (Origin , origin )
4461
4562 url := fmt .Sprintf ("ws://%s/ws" , addr )
46- ws , err := websocket .Dial (url , "" , origin )
63+ ws , _ , err := websocket .DefaultDialer .Dial (url , header )
64+ if err != nil {
65+ log .Fatal ("dial:" , err )
66+ }
4767 Equal (t , err , nil )
4868
4969 defer ws .Close ()
5070
51- _ , err = ws .Write ( []byte ("websockets in action!" ))
71+ err = ws .WriteMessage ( websocket . TextMessage , []byte ("websockets in action!" ))
5272 Equal (t , err , nil )
5373
54- buf := new (bytes.Buffer )
55- _ , err = buf .ReadFrom (ws )
74+ typ , b , err := ws .ReadMessage ()
5675 Equal (t , err , nil )
57- Equal (t , "websockets in action!" , buf .String ())
76+ Equal (t , typ , websocket .TextMessage )
77+ Equal (t , "websockets in action!" , string (b ))
78+
79+ wsBad , res , err := websocket .DefaultDialer .Dial (url , nil )
80+ NotEqual (t , err , nil )
81+ Equal (t , wsBad , nil )
82+ Equal (t , res .StatusCode , http .StatusForbidden )
5883}
0 commit comments