Skip to content

Commit bfdb48b

Browse files
authored
[ENH] Use GCS with aws-sdk-go-v2 (#5878)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - N/A - New functionality - Hacks the S3 client initialization in SysDB so that it can support GCS - Source: https://stackoverflow.com/questions/73717477/gcp-cloud-storage-golang-aws-sdk2-upload-file-with-s3-interoperability-creds ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent ce07447 commit bfdb48b

File tree

2 files changed

+73
-22
lines changed

2 files changed

+73
-22
lines changed

go/cmd/coordinator/cmd.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ func init() {
8080
Cmd.Flags().StringVar(&conf.MetaStoreConfig.AccessKeyID, "s3-access-key-id", "", "S3 access key ID")
8181
Cmd.Flags().StringVar(&conf.MetaStoreConfig.SecretAccessKey, "s3-secret-access-key", "", "S3 secret access key")
8282
Cmd.Flags().BoolVar(&conf.MetaStoreConfig.ForcePathStyle, "s3-force-path-style", false, "S3 force path style")
83+
Cmd.Flags().BoolVar(&conf.MetaStoreConfig.GCSInterop, "s3-gcs-interop", false, "Enable Google Cloud Storage support for S3 client")
8384

8485
// Version file
8586
Cmd.Flags().BoolVar(&conf.VersionFileEnabled, "version-file-enabled", false, "Enable version file")

go/pkg/sysdb/metastore/s3/impl.go

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@ import (
66
"errors"
77
"fmt"
88
"io"
9+
"net/http"
910
"strings"
11+
"time"
1012

1113
"github.com/aws/aws-sdk-go-v2/aws"
14+
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
1215
"github.com/aws/aws-sdk-go-v2/config"
1316
"github.com/aws/aws-sdk-go-v2/credentials"
1417
"github.com/aws/aws-sdk-go-v2/service/s3"
@@ -20,6 +23,52 @@ import (
2023
"google.golang.org/protobuf/proto"
2124
)
2225

26+
// NOTE(sicheng): As a temporary solution we use the AWS SDK with GCS, but this approach needs a few tweaks:
27+
// https://stackoverflow.com/questions/73717477/gcp-cloud-storage-golang-aws-sdk2-upload-file-with-s3-interoperability-creds
28+
//
29+
// In summary, the AWS SDK we are using (1.36.3) is not fully compatible with GCS because it uses an additional header
30+
// for signing. We need to add a middleware to remove the offending header, as suggested by the thread above.
31+
//
32+
// If we upgrade the AWS SDK to a version higher than 1.73.0, we need an additional tweak:
33+
// cfg.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
34+
type RecalculateV4Signature struct {
35+
next http.RoundTripper
36+
signer *v4.Signer
37+
cfg aws.Config
38+
}
39+
40+
// NOTE(sicheng): Code borrowed from https://stackoverflow.com/questions/73717477/gcp-cloud-storage-golang-aws-sdk2-upload-file-with-s3-interoperability-creds
41+
func (lt *RecalculateV4Signature) RoundTrip(req *http.Request) (*http.Response, error) {
42+
// store for later use
43+
val := req.Header.Get("Accept-Encoding")
44+
45+
// delete the header so the header doesn't account for in the signature
46+
req.Header.Del("Accept-Encoding")
47+
48+
// sign with the same date
49+
timeString := req.Header.Get("X-Amz-Date")
50+
timeDate, err := time.Parse("20060102T150405Z", timeString)
51+
if err != nil {
52+
return nil, err
53+
}
54+
55+
creds, err := lt.cfg.Credentials.Retrieve(req.Context())
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
err = lt.signer.SignHTTP(req.Context(), creds, req, v4.GetPayloadHash(req.Context()), "s3", lt.cfg.Region, timeDate)
61+
if err != nil {
62+
return nil, err
63+
}
64+
65+
// Reset Accept-Encoding if desired
66+
req.Header.Set("Accept-Encoding", val)
67+
68+
// follows up the original round tripper
69+
return lt.next.RoundTrip(req)
70+
}
71+
2372
// Path to Version Files in S3.
2473
// Example:
2574
// s3://<bucket-name>/tenant/<tenant_id>/databases/<database_id>/collections/<collection_id>/versionfiles/file_name
@@ -37,6 +86,7 @@ type S3MetaStoreConfig struct {
3786
AccessKeyID string
3887
SecretAccessKey string
3988
ForcePathStyle bool
89+
GCSInterop bool
4090
}
4191

4292
type S3MetaStoreInterface interface {
@@ -101,35 +151,35 @@ func NewS3MetaStore(ctx context.Context, cfg S3MetaStoreConfig) (*S3MetaStore, e
101151
var awsConfig aws.Config
102152
var err error
103153

154+
awsConfigParts := []func(*config.LoadOptions) error{config.WithRegion(region)}
155+
104156
if cfg.AccessKeyID != "" && cfg.SecretAccessKey != "" {
105157
creds := credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, "")
106-
if cfg.Endpoint != "" {
107-
awsConfig, err = config.LoadDefaultConfig(ctx,
108-
config.WithCredentialsProvider(creds),
109-
config.WithRegion(region),
110-
)
111-
} else {
112-
awsConfig, err = config.LoadDefaultConfig(ctx,
113-
config.WithCredentialsProvider(creds),
114-
config.WithRegion(region),
115-
)
116-
}
117-
} else {
118-
if cfg.Endpoint != "" {
119-
awsConfig, err = config.LoadDefaultConfig(ctx,
120-
config.WithRegion(region),
121-
)
122-
} else {
123-
awsConfig, err = config.LoadDefaultConfig(ctx,
124-
config.WithRegion(region),
125-
)
126-
}
158+
awsConfigParts = append(awsConfigParts, config.WithCredentialsProvider(creds))
127159
}
128160

161+
if cfg.GCSInterop && cfg.Endpoint != "" {
162+
resolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...any) (aws.Endpoint, error) {
163+
return aws.Endpoint{
164+
URL: cfg.Endpoint,
165+
SigningRegion: cfg.Region,
166+
Source: aws.EndpointSourceCustom,
167+
HostnameImmutable: true,
168+
}, nil
169+
})
170+
awsConfigParts = append(awsConfigParts, config.WithEndpointResolverWithOptions(resolver))
171+
}
172+
173+
awsConfig, err = config.LoadDefaultConfig(ctx, awsConfigParts...)
129174
if err != nil {
130175
return nil, err
131176
}
132177

178+
// Add middleware to remove offending header for signing for GCS Support
179+
if cfg.GCSInterop {
180+
awsConfig.HTTPClient = &http.Client{Transport: &RecalculateV4Signature{http.DefaultTransport, v4.NewSigner(), awsConfig}}
181+
}
182+
133183
// Create S3 client with optional path-style addressing and custom endpoint
134184
otelaws.AppendMiddlewares(&awsConfig.APIOptions)
135185
s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) {
@@ -159,7 +209,7 @@ func NewS3MetaStore(ctx context.Context, cfg S3MetaStoreConfig) (*S3MetaStore, e
159209
}
160210

161211
// Verify we have access to the bucket
162-
_, err = s3Client.HeadBucket(ctx, &s3.HeadBucketInput{
212+
_, err = s3Client.ListObjects(ctx, &s3.ListObjectsInput{
163213
Bucket: aws.String(bucketName),
164214
})
165215
if err != nil {

0 commit comments

Comments
 (0)