Skip to content

Commit bcd9081

Browse files
authored
feat: Add RemoteOAuth Token helper to refresh access_token from cloud environment (#1866)
Implements cloudquery/cloudquery-issues#1978 (internal issue)
1 parent d1dd099 commit bcd9081

File tree

6 files changed

+495
-0
lines changed

6 files changed

+495
-0
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ require (
3535
go.opentelemetry.io/otel/sdk/metric v1.28.0
3636
go.opentelemetry.io/otel/trace v1.28.0
3737
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56
38+
golang.org/x/oauth2 v0.20.0
3839
golang.org/x/sync v0.7.0
3940
golang.org/x/text v0.16.0
4041
google.golang.org/grpc v1.65.0

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
184184
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
185185
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
186186
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
187+
golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo=
188+
golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
187189
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
188190
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
189191
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=

helpers/remoteoauth/token.go

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package remoteoauth
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net/http"
8+
"os"
9+
10+
cloudquery_api "github.com/cloudquery/cloudquery-api-go"
11+
"github.com/google/uuid"
12+
"golang.org/x/oauth2"
13+
)
14+
15+
func NewTokenSource(opts ...TokenSourceOption) (oauth2.TokenSource, error) {
16+
t := &tokenSource{}
17+
for _, opt := range opts {
18+
opt(t)
19+
}
20+
21+
if _, cloudEnabled := os.LookupEnv("CQ_CLOUD"); !cloudEnabled {
22+
return oauth2.StaticTokenSource(&t.currentToken), nil
23+
}
24+
25+
cloudToken, err := newCloudTokenSource(t.defaultContext)
26+
if err != nil {
27+
return nil, err
28+
}
29+
if t.noWrap {
30+
return cloudToken, nil
31+
}
32+
33+
return oauth2.ReuseTokenSource(nil, cloudToken), nil
34+
}
35+
36+
type tokenSource struct {
37+
defaultContext context.Context
38+
currentToken oauth2.Token
39+
noWrap bool
40+
}
41+
42+
type cloudTokenSource struct {
43+
defaultContext context.Context
44+
apiClient *cloudquery_api.ClientWithResponses
45+
46+
apiURL string
47+
apiToken string
48+
teamName string
49+
syncName string
50+
testConnUUID uuid.UUID
51+
syncRunUUID uuid.UUID
52+
connectorUUID uuid.UUID
53+
isTestConnection bool
54+
}
55+
56+
var _ oauth2.TokenSource = (*cloudTokenSource)(nil)
57+
58+
func newCloudTokenSource(defaultContext context.Context) (oauth2.TokenSource, error) {
59+
t := &cloudTokenSource{
60+
defaultContext: defaultContext,
61+
}
62+
if t.defaultContext == nil {
63+
t.defaultContext = context.Background()
64+
}
65+
66+
err := t.initCloudOpts()
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
t.apiClient, err = cloudquery_api.NewClientWithResponses(t.apiURL,
72+
cloudquery_api.WithRequestEditorFn(func(_ context.Context, req *http.Request) error {
73+
req.Header.Set("Authorization", "Bearer "+t.apiToken)
74+
return nil
75+
}))
76+
if err != nil {
77+
return nil, fmt.Errorf("failed to create api client: %w", err)
78+
}
79+
80+
return t, nil
81+
}
82+
83+
// Token returns a new token from the remote source using the default context.
84+
func (t *cloudTokenSource) Token() (*oauth2.Token, error) {
85+
return t.retrieveToken(t.defaultContext)
86+
}
87+
88+
func (t *cloudTokenSource) retrieveToken(ctx context.Context) (*oauth2.Token, error) {
89+
var oauthResp *cloudquery_api.ConnectorCredentialsResponseOAuth
90+
if !t.isTestConnection {
91+
resp, err := t.apiClient.GetSyncRunConnectorCredentialsWithResponse(ctx, t.teamName, t.syncName, t.syncRunUUID, t.connectorUUID)
92+
if err != nil {
93+
return nil, fmt.Errorf("failed to get sync run connector credentials: %w", err)
94+
}
95+
if resp.StatusCode() != http.StatusOK {
96+
if resp.JSON422 != nil {
97+
return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.JSON422.Message)
98+
}
99+
return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.Status())
100+
}
101+
oauthResp = resp.JSON200.Oauth
102+
} else {
103+
resp, err := t.apiClient.GetTestConnectionConnectorCredentialsWithResponse(ctx, t.teamName, t.testConnUUID, t.connectorUUID)
104+
if err != nil {
105+
return nil, fmt.Errorf("failed to get test connection connector credentials: %w", err)
106+
}
107+
if resp.StatusCode() != http.StatusOK {
108+
if resp.JSON422 != nil {
109+
return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.JSON422.Message)
110+
}
111+
return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.Status())
112+
}
113+
oauthResp = resp.JSON200.Oauth
114+
}
115+
116+
if oauthResp == nil {
117+
return nil, fmt.Errorf("missing oauth credentials in response")
118+
}
119+
120+
tok := &oauth2.Token{
121+
AccessToken: oauthResp.AccessToken,
122+
}
123+
if oauthResp.Expires != nil {
124+
tok.Expiry = *oauthResp.Expires
125+
}
126+
return tok, nil
127+
}
128+
129+
func (t *cloudTokenSource) initCloudOpts() error {
130+
var allErr error
131+
132+
t.apiToken = os.Getenv("CLOUDQUERY_API_KEY")
133+
if t.apiToken == "" {
134+
allErr = errors.Join(allErr, errors.New("CLOUDQUERY_API_KEY missing"))
135+
}
136+
t.apiURL = os.Getenv("CLOUDQUERY_API_URL")
137+
if t.apiURL == "" {
138+
t.apiURL = "https://api.cloudquery.io"
139+
}
140+
141+
t.teamName = os.Getenv("_CQ_TEAM_NAME")
142+
if t.teamName == "" {
143+
allErr = errors.Join(allErr, errors.New("_CQ_TEAM_NAME missing"))
144+
}
145+
t.syncName = os.Getenv("_CQ_SYNC_NAME")
146+
syncRunID := os.Getenv("_CQ_SYNC_RUN_ID")
147+
testConnID := os.Getenv("_CQ_SYNC_TEST_CONNECTION_ID")
148+
if testConnID == "" && syncRunID == "" {
149+
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID or _CQ_SYNC_RUN_ID missing"))
150+
} else if testConnID != "" && syncRunID != "" {
151+
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID and _CQ_SYNC_RUN_ID are mutually exclusive"))
152+
}
153+
154+
var err error
155+
if syncRunID != "" {
156+
if t.syncName == "" {
157+
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME missing"))
158+
}
159+
160+
t.syncRunUUID, err = uuid.Parse(syncRunID)
161+
if err != nil {
162+
allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_RUN_ID is not a valid UUID: %w", err))
163+
}
164+
}
165+
if testConnID != "" {
166+
if t.syncName != "" {
167+
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME should be empty"))
168+
}
169+
170+
t.testConnUUID, err = uuid.Parse(testConnID)
171+
if err != nil {
172+
allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_TEST_CONNECTION_ID is not a valid UUID: %w", err))
173+
}
174+
t.isTestConnection = true
175+
}
176+
177+
connectorID := os.Getenv("_CQ_CONNECTOR_ID")
178+
if connectorID == "" {
179+
allErr = errors.Join(allErr, errors.New("_CQ_CONNECTOR_ID missing"))
180+
} else {
181+
t.connectorUUID, err = uuid.Parse(connectorID)
182+
if err != nil {
183+
allErr = errors.Join(allErr, fmt.Errorf("_CQ_CONNECTOR_ID is not a valid UUID: %w", err))
184+
}
185+
}
186+
return allErr
187+
}

helpers/remoteoauth/token_test.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package remoteoauth
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"os"
7+
"testing"
8+
"time"
9+
10+
"github.com/google/uuid"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
const testAPIKey = "test-key"
15+
16+
func TestLocalTokenAccess(t *testing.T) {
17+
r := require.New(t)
18+
_, cloud := os.LookupEnv("CQ_CLOUD")
19+
r.False(cloud, "CQ_CLOUD should not be set")
20+
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{}))
21+
r.NoError(err)
22+
tk, err := tok.Token()
23+
r.NoError(err)
24+
r.True(tk.Valid())
25+
r.Equal("token", tk.AccessToken)
26+
}
27+
28+
func TestFirstLocalTokenAccess(t *testing.T) {
29+
runID := uuid.NewString()
30+
connID := uuid.NewString()
31+
testURL := setupMockTokenServer(t, map[string]string{
32+
"/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`,
33+
})
34+
setEnvs(t, map[string]string{
35+
"CQ_CLOUD": "1",
36+
"CLOUDQUERY_API_URL": testURL,
37+
"CLOUDQUERY_API_KEY": testAPIKey,
38+
"_CQ_TEAM_NAME": "the-team",
39+
"_CQ_SYNC_NAME": "the-sync",
40+
"_CQ_SYNC_RUN_ID": runID,
41+
"_CQ_CONNECTOR_ID": connID,
42+
})
43+
r := require.New(t)
44+
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{}))
45+
r.NoError(err)
46+
tk, err := tok.Token()
47+
r.NoError(err)
48+
r.True(tk.Valid())
49+
r.Equal("new-token", tk.AccessToken)
50+
}
51+
52+
func TestInvalidAPIKeyTokenAccess(t *testing.T) {
53+
runID := uuid.NewString()
54+
connID := uuid.NewString()
55+
testURL := setupMockTokenServer(t, nil)
56+
setEnvs(t, map[string]string{
57+
"CQ_CLOUD": "1",
58+
"CLOUDQUERY_API_URL": testURL,
59+
"CLOUDQUERY_API_KEY": "invalid",
60+
"_CQ_TEAM_NAME": "the-team",
61+
"_CQ_SYNC_NAME": "the-sync",
62+
"_CQ_SYNC_RUN_ID": runID,
63+
"_CQ_CONNECTOR_ID": connID,
64+
})
65+
r := require.New(t)
66+
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{}))
67+
r.NoError(err)
68+
tk, err := tok.Token()
69+
r.Nil(tk)
70+
r.False(tk.Valid())
71+
r.ErrorContains(err, "failed to get sync run connector credentials")
72+
}
73+
74+
func TestSyncRunTokenAccess(t *testing.T) {
75+
runID := uuid.NewString()
76+
connID := uuid.NewString()
77+
testURL := setupMockTokenServer(t, map[string]string{
78+
"/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`,
79+
})
80+
setEnvs(t, map[string]string{
81+
"CQ_CLOUD": "1",
82+
"CLOUDQUERY_API_URL": testURL,
83+
"CLOUDQUERY_API_KEY": testAPIKey,
84+
"_CQ_TEAM_NAME": "the-team",
85+
"_CQ_SYNC_NAME": "the-sync",
86+
"_CQ_SYNC_RUN_ID": runID,
87+
"_CQ_CONNECTOR_ID": connID,
88+
})
89+
r := require.New(t)
90+
tok, err := NewTokenSource()
91+
r.NoError(err)
92+
tk, err := tok.Token()
93+
r.NoError(err)
94+
r.True(tk.Valid())
95+
r.Equal("new-token", tk.AccessToken)
96+
}
97+
98+
func TestTestConnectionTokenAccess(t *testing.T) {
99+
testID := uuid.NewString()
100+
connID := uuid.NewString()
101+
testURL := setupMockTokenServer(t, map[string]string{
102+
"/teams/the-team/syncs/test-connections/" + testID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`,
103+
})
104+
setEnvs(t, map[string]string{
105+
"CQ_CLOUD": "1",
106+
"CLOUDQUERY_API_URL": testURL,
107+
"CLOUDQUERY_API_KEY": testAPIKey,
108+
"_CQ_TEAM_NAME": "the-team",
109+
"_CQ_SYNC_TEST_CONNECTION_ID": testID,
110+
"_CQ_CONNECTOR_ID": connID,
111+
})
112+
r := require.New(t)
113+
tok, err := NewTokenSource(WithAccessToken("token", "bearer", time.Time{}))
114+
r.NoError(err)
115+
tk, err := tok.Token()
116+
r.NoError(err)
117+
r.True(tk.Valid())
118+
r.Equal("new-token", tk.AccessToken)
119+
}
120+
121+
func setEnvs(t *testing.T, envs map[string]string) {
122+
t.Helper()
123+
for k, v := range envs {
124+
t.Setenv(k, v)
125+
}
126+
}
127+
128+
func setupMockTokenServer(t *testing.T, responses map[string]string) string {
129+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
130+
if a := r.Header.Get("Authorization"); a != "Bearer "+testAPIKey {
131+
w.WriteHeader(http.StatusUnauthorized)
132+
return
133+
}
134+
135+
resp, ok := responses[r.URL.Path]
136+
if !ok {
137+
w.WriteHeader(http.StatusNotFound)
138+
return
139+
}
140+
141+
w.Header().Set("Content-Type", "application/json")
142+
w.WriteHeader(http.StatusOK)
143+
w.Write([]byte(resp))
144+
}))
145+
t.Cleanup(func() {
146+
ts.Close()
147+
})
148+
return ts.URL
149+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package remoteoauth
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"golang.org/x/oauth2"
8+
)
9+
10+
type TokenSourceOption func(source *tokenSource)
11+
12+
func WithAccessToken(token, tokenType string, expiry time.Time) TokenSourceOption {
13+
return func(t *tokenSource) {
14+
t.currentToken = oauth2.Token{
15+
AccessToken: token,
16+
TokenType: tokenType,
17+
Expiry: expiry,
18+
}
19+
}
20+
}
21+
22+
func WithDefaultContext(ctx context.Context) TokenSourceOption {
23+
return func(t *tokenSource) {
24+
t.defaultContext = ctx
25+
}
26+
}
27+
28+
func withNoWrap() TokenSourceOption {
29+
return func(t *tokenSource) {
30+
t.noWrap = true
31+
}
32+
}

0 commit comments

Comments
 (0)