Skip to content

Commit f3f3488

Browse files
committed
removed old httpProtocol/httpTransport code, removed authInfo as it's replaced by interceptors, and added basic & sigv4 auth ref
1 parent 19c0c8b commit f3f3488

17 files changed

+332
-612
lines changed

gremlin-go/driver/auth.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
*/
19+
20+
package gremlingo
21+
22+
import (
23+
"context"
24+
"encoding/base64"
25+
"time"
26+
27+
"github.com/aws/aws-sdk-go-v2/aws"
28+
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
29+
"github.com/aws/aws-sdk-go-v2/config"
30+
)
31+
32+
// BasicAuth returns a RequestInterceptor that adds Basic authentication header.
33+
func BasicAuth(username, password string) RequestInterceptor {
34+
encoded := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
35+
return func(req *HttpRequest) error {
36+
req.Headers.Set(HeaderAuthorization, "Basic "+encoded)
37+
return nil
38+
}
39+
}
40+
41+
// Sigv4Auth returns a RequestInterceptor that signs requests using AWS SigV4.
42+
// It uses the default AWS credential chain (env vars, shared config, IAM role, etc.)
43+
func Sigv4Auth(region, service string) RequestInterceptor {
44+
return Sigv4AuthWithCredentials(region, service, nil)
45+
}
46+
47+
// Sigv4AuthWithCredentials returns a RequestInterceptor that signs requests using AWS SigV4
48+
// with the provided credentials provider. If provider is nil, uses default credential chain.
49+
func Sigv4AuthWithCredentials(region, service string, credentialsProvider aws.CredentialsProvider) RequestInterceptor {
50+
return func(req *HttpRequest) error {
51+
ctx := context.Background()
52+
53+
creds, err := resolveCredentials(ctx, region, credentialsProvider)
54+
if err != nil {
55+
return err
56+
}
57+
58+
signer := v4.NewSigner()
59+
stdReq := req.ToStdRequest()
60+
stdReq.Body = nil // Body is handled separately via payload hash
61+
62+
if err := signer.SignHTTP(ctx, creds, stdReq, req.PayloadHash(), service, region, time.Now()); err != nil {
63+
return err
64+
}
65+
66+
// Copy signed headers back to HttpRequest
67+
for k, v := range stdReq.Header {
68+
req.Headers[k] = v
69+
}
70+
71+
return nil
72+
}
73+
}
74+
75+
func resolveCredentials(ctx context.Context, region string, provider aws.CredentialsProvider) (aws.Credentials, error) {
76+
if provider != nil {
77+
return provider.Retrieve(ctx)
78+
}
79+
80+
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
81+
if err != nil {
82+
return aws.Credentials{}, err
83+
}
84+
return cfg.Credentials.Retrieve(ctx)
85+
}

gremlin-go/driver/authInfo.go

Lines changed: 0 additions & 95 deletions
This file was deleted.

gremlin-go/driver/auth_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
*/
19+
20+
package gremlingo
21+
22+
import (
23+
"context"
24+
"encoding/base64"
25+
"strings"
26+
"testing"
27+
28+
"github.com/aws/aws-sdk-go-v2/aws"
29+
"github.com/stretchr/testify/assert"
30+
)
31+
32+
func createMockRequest() *HttpRequest {
33+
req, _ := NewHttpRequest("POST", "https://localhost:8182/gremlin")
34+
req.Headers.Set("Content-Type", graphBinaryMimeType)
35+
req.Headers.Set("Accept", graphBinaryMimeType)
36+
req.Body = []byte(`{"gremlin":"g.V()"}`)
37+
return req
38+
}
39+
40+
func TestBasicAuth(t *testing.T) {
41+
t.Run("adds authorization header", func(t *testing.T) {
42+
req := createMockRequest()
43+
assert.Empty(t, req.Headers.Get(HeaderAuthorization))
44+
45+
interceptor := BasicAuth("username", "password")
46+
err := interceptor(req)
47+
48+
assert.NoError(t, err)
49+
authHeader := req.Headers.Get(HeaderAuthorization)
50+
assert.True(t, strings.HasPrefix(authHeader, "Basic "))
51+
52+
// Verify encoding
53+
encoded := strings.TrimPrefix(authHeader, "Basic ")
54+
decoded, err := base64.StdEncoding.DecodeString(encoded)
55+
assert.NoError(t, err)
56+
assert.Equal(t, "username:password", string(decoded))
57+
})
58+
}
59+
60+
// mockCredentialsProvider implements aws.CredentialsProvider for testing
61+
type mockCredentialsProvider struct {
62+
accessKey string
63+
secretKey string
64+
sessionToken string
65+
}
66+
67+
func (m *mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
68+
return aws.Credentials{
69+
AccessKeyID: m.accessKey,
70+
SecretAccessKey: m.secretKey,
71+
SessionToken: m.sessionToken,
72+
}, nil
73+
}
74+
75+
func TestSigv4Auth(t *testing.T) {
76+
t.Run("adds signed headers", func(t *testing.T) {
77+
req := createMockRequest()
78+
assert.Empty(t, req.Headers.Get("Authorization"))
79+
assert.Empty(t, req.Headers.Get("X-Amz-Date"))
80+
81+
provider := &mockCredentialsProvider{
82+
accessKey: "MOCK_ACCESS_KEY",
83+
secretKey: "MOCK_SECRET_KEY",
84+
}
85+
interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider)
86+
err := interceptor(req)
87+
88+
assert.NoError(t, err)
89+
assert.NotEmpty(t, req.Headers.Get("X-Amz-Date"))
90+
authHeader := req.Headers.Get("Authorization")
91+
assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=MOCK_ACCESS_KEY"))
92+
assert.Contains(t, authHeader, "us-west-2/neptune-db/aws4_request")
93+
assert.Contains(t, authHeader, "Signature=")
94+
})
95+
96+
t.Run("adds session token when provided", func(t *testing.T) {
97+
req := createMockRequest()
98+
assert.Empty(t, req.Headers.Get("X-Amz-Security-Token"))
99+
100+
provider := &mockCredentialsProvider{
101+
accessKey: "MOCK_ACCESS_KEY",
102+
secretKey: "MOCK_SECRET_KEY",
103+
sessionToken: "MOCK_SESSION_TOKEN",
104+
}
105+
interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider)
106+
err := interceptor(req)
107+
108+
assert.NoError(t, err)
109+
assert.Equal(t, "MOCK_SESSION_TOKEN", req.Headers.Get("X-Amz-Security-Token"))
110+
authHeader := req.Headers.Get("Authorization")
111+
assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential="))
112+
assert.Contains(t, authHeader, "Signature=")
113+
})
114+
}

gremlin-go/driver/client.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ type ClientSettings struct {
3636
LogVerbosity LogVerbosity
3737
Logger Logger
3838
Language language.Tag
39-
AuthInfo AuthInfoProvider
4039
TlsConfig *tls.Config
4140
ConnectionTimeout time.Duration
4241
EnableCompression bool
@@ -72,7 +71,6 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C
7271
LogVerbosity: Info,
7372
Logger: &defaultLogger{},
7473
Language: language.English,
75-
AuthInfo: &AuthInfo{},
7674
TlsConfig: &tls.Config{},
7775
ConnectionTimeout: connectionTimeoutDefault,
7876
EnableCompression: false,
@@ -85,7 +83,6 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C
8583
}
8684

8785
connSettings := &connectionSettings{
88-
authInfo: settings.AuthInfo,
8986
tlsConfig: settings.TlsConfig,
9087
connectionTimeout: settings.ConnectionTimeout,
9188
enableCompression: settings.EnableCompression,

gremlin-go/driver/client_test.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,13 @@ func TestClient(t *testing.T) {
3030
// Integration test variables.
3131
testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl)
3232
testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true)
33-
testNoAuthAuthInfo := &AuthInfo{}
3433
testNoAuthTlsConfig := &tls.Config{}
3534

3635
t.Run("Test client.SubmitWithOptions()", func(t *testing.T) {
3736
skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable)
3837
client, err := NewClient(testNoAuthUrl,
3938
func(settings *ClientSettings) {
4039
settings.TlsConfig = testNoAuthTlsConfig
41-
settings.AuthInfo = testNoAuthAuthInfo
4240
})
4341
assert.NoError(t, err)
4442
assert.NotNil(t, client)
@@ -58,7 +56,6 @@ func TestClient(t *testing.T) {
5856
client, err := NewClient(testNoAuthUrl,
5957
func(settings *ClientSettings) {
6058
settings.TlsConfig = testNoAuthTlsConfig
61-
settings.AuthInfo = testNoAuthAuthInfo
6259
})
6360
assert.NoError(t, err)
6461
assert.NotNil(t, client)
@@ -74,7 +71,6 @@ func TestClient(t *testing.T) {
7471
client, err := NewClient(testNoAuthUrl,
7572
func(settings *ClientSettings) {
7673
settings.TlsConfig = testNoAuthTlsConfig
77-
settings.AuthInfo = testNoAuthAuthInfo
7874
settings.TraversalSource = testServerModernGraphAlias
7975
})
8076
assert.NoError(t, err)
@@ -97,7 +93,6 @@ func TestClient(t *testing.T) {
9793
client, err := NewClient(testNoAuthUrl,
9894
func(settings *ClientSettings) {
9995
settings.TlsConfig = testNoAuthTlsConfig
100-
settings.AuthInfo = testNoAuthAuthInfo
10196
settings.TraversalSource = testServerModernGraphAlias
10297
})
10398
assert.NoError(t, err)
@@ -122,7 +117,6 @@ func TestClient(t *testing.T) {
122117
client, err := NewClient(testNoAuthUrl,
123118
func(settings *ClientSettings) {
124119
settings.TlsConfig = testNoAuthTlsConfig
125-
settings.AuthInfo = testNoAuthAuthInfo
126120
settings.TraversalSource = testServerModernGraphAlias
127121
})
128122
assert.NoError(t, err)
@@ -147,7 +141,6 @@ func TestClient(t *testing.T) {
147141
client, err := NewClient(testNoAuthUrl,
148142
func(settings *ClientSettings) {
149143
settings.TlsConfig = testNoAuthTlsConfig
150-
settings.AuthInfo = testNoAuthAuthInfo
151144
settings.TraversalSource = testServerModernGraphAlias
152145
})
153146

@@ -170,7 +163,6 @@ func TestClient(t *testing.T) {
170163
client, err := NewClient(testNoAuthUrl,
171164
func(settings *ClientSettings) {
172165
settings.TlsConfig = testNoAuthTlsConfig
173-
settings.AuthInfo = testNoAuthAuthInfo
174166
settings.TraversalSource = testServerCrewGraphAlias
175167
})
176168

gremlin-go/driver/connection.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
)
2626

2727
type connectionSettings struct {
28-
authInfo AuthInfoProvider
2928
tlsConfig *tls.Config
3029
connectionTimeout time.Duration
3130
enableCompression bool

0 commit comments

Comments
 (0)