Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package grammes

import (
"crypto/tls"
"strconv"
"time"

Expand Down Expand Up @@ -102,3 +103,10 @@ func WithReadingWait(interval time.Duration) ClientConfiguration {
c.conn.SetReadingWait(interval)
}
}

// WithTLS sets the TLS config
func WithTLS(conf *tls.Config) ClientConfiguration {
return func(c *Client) {
c.conn.SetTLSConfig(conf)
}
}
16 changes: 16 additions & 0 deletions configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package grammes

import (
"crypto/tls"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -162,3 +163,18 @@ func TestWithReadingWait(t *testing.T) {
})
})
}

func TestWithTLS(t *testing.T) {
t.Parallel()

Convey("Given a tls config and dialer", t, func() {
dialer := &mockDialerStruct{}
dialer.tlsConfig = &tls.Config{}
Convey("And Dial is called with tls config", func() {
_, err := mockDial(dialer, WithTLS(dialer.tlsConfig))
Convey("Then no error should be encountered", func() {
So(err, ShouldBeNil)
})
})
})
}
6 changes: 5 additions & 1 deletion gremconnect/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

package gremconnect

import "time"
import (
"crypto/tls"
"time"
)

// Dialer will be used to dial in a connection
// between the client and gremlin server without
Expand Down Expand Up @@ -50,6 +53,7 @@ type Dialer interface {
SetPingInterval(interval time.Duration)
SetWritingWait(interval time.Duration)
SetReadingWait(interval time.Duration)
SetTLSConfig(conf *tls.Config)
}

// NewWebSocketDialer returns a new WebSocket dialer to use when
Expand Down
8 changes: 8 additions & 0 deletions gremconnect/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package gremconnect

import (
"crypto/tls"
"errors"
"net/http"
"strings"
Expand All @@ -36,6 +37,7 @@ import (
type WebSocket struct {
address string
conn *websocket.Conn
tlsConfig *tls.Config
auth *Auth
disposed bool
connected bool
Expand All @@ -54,6 +56,7 @@ type WebSocket struct {
func (ws *WebSocket) Connect() error {
var err error
dialer := websocket.Dialer{
TLSClientConfig: ws.tlsConfig,
WriteBufferSize: 1024 * 8, // Set up for large messages.
ReadBufferSize: 1024 * 8, // Set up for large messages.
HandshakeTimeout: 5 * time.Second,
Expand Down Expand Up @@ -200,3 +203,8 @@ func (ws *WebSocket) SetWritingWait(interval time.Duration) {
func (ws *WebSocket) SetReadingWait(interval time.Duration) {
ws.readingWait = interval
}

// SetReadingWait sets how long the reading will wait
func (ws *WebSocket) SetTLSConfig(conf *tls.Config) {
ws.tlsConfig = conf
}
2 changes: 2 additions & 0 deletions manager/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package manager

import (
"crypto/tls"
"testing"
"time"

Expand Down Expand Up @@ -48,6 +49,7 @@ func (*mockDialer) SetTimeout(time.Duration) {}
func (*mockDialer) SetPingInterval(time.Duration) {}
func (*mockDialer) SetWritingWait(time.Duration) {}
func (*mockDialer) SetReadingWait(time.Duration) {}
func (*mockDialer) SetTLSConfig(*tls.Config) {}

func TestSetLoggerQM(t *testing.T) {
Convey("Given a dialer, string executor and query manager", t, func() {
Expand Down
2 changes: 2 additions & 0 deletions quick/addvertex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package quick

import (
"crypto/tls"
"errors"
"testing"
"time"
Expand Down Expand Up @@ -80,6 +81,7 @@ func (*mockDialer) SetTimeout(time.Duration) {}
func (*mockDialer) SetPingInterval(time.Duration) {}
func (*mockDialer) SetWritingWait(time.Duration) {}
func (*mockDialer) SetReadingWait(time.Duration) {}
func (*mockDialer) SetTLSConfig(*tls.Config) {}

// MOCKQUERY

Expand Down
6 changes: 6 additions & 0 deletions testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package grammes

import (
"crypto/tls"
"errors"
"time"

Expand Down Expand Up @@ -105,6 +106,7 @@ type mockDialerStruct struct {
logger testLogger
address string
conn *websocket.Conn
tlsConfig *tls.Config
auth *gremconnect.Auth
disposed bool
connected bool
Expand Down Expand Up @@ -137,6 +139,7 @@ func (*mockDialerStruct) SetTimeout(time.Duration) {}
func (*mockDialerStruct) SetPingInterval(time.Duration) {}
func (*mockDialerStruct) SetWritingWait(time.Duration) {}
func (*mockDialerStruct) SetReadingWait(time.Duration) {}
func (*mockDialerStruct) SetTLSConfig(*tls.Config) {}

func mockDial(conn gremconnect.Dialer, cfgs ...ClientConfiguration) (*Client, error) {
c := setupClient()
Expand Down Expand Up @@ -167,6 +170,7 @@ func (*mockDialerWriteError) SetTimeout(time.Duration) {}
func (*mockDialerWriteError) SetPingInterval(time.Duration) {}
func (*mockDialerWriteError) SetWritingWait(time.Duration) {}
func (*mockDialerWriteError) SetReadingWait(time.Duration) {}
func (*mockDialerWriteError) SetTLSConfig(*tls.Config) {}

type mockDialerAuthError gremconnect.WebSocket

Expand All @@ -192,6 +196,7 @@ func (*mockDialerAuthError) SetTimeout(time.Duration) {}
func (*mockDialerAuthError) SetPingInterval(time.Duration) {}
func (*mockDialerAuthError) SetWritingWait(time.Duration) {}
func (*mockDialerAuthError) SetReadingWait(time.Duration) {}
func (*mockDialerAuthError) SetTLSConfig(*tls.Config) {}

type mockDialerReadError gremconnect.WebSocket

Expand All @@ -217,3 +222,4 @@ func (*mockDialerReadError) SetTimeout(time.Duration) {}
func (*mockDialerReadError) SetPingInterval(time.Duration) {}
func (*mockDialerReadError) SetWritingWait(time.Duration) {}
func (*mockDialerReadError) SetReadingWait(time.Duration) {}
func (*mockDialerReadError) SetTLSConfig(*tls.Config) {}