Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 4 additions & 172 deletions helpers/remoteoauth/token.go
Original file line number Diff line number Diff line change
@@ -1,187 +1,19 @@
package remoteoauth

import (
"context"
"errors"
"fmt"
"net/http"
"os"

cloudquery_api "github.com/cloudquery/cloudquery-api-go"
"github.com/google/uuid"
"golang.org/x/oauth2"
)

// NewTokenSource creates a new token source.
// Deprecated: Use oauth2.StaticTokenSource directly instead.
func NewTokenSource(opts ...TokenSourceOption) (oauth2.TokenSource, error) {
t := &tokenSource{}
for _, opt := range opts {
opt(t)
}

if _, cloudEnabled := os.LookupEnv("CQ_CLOUD"); !cloudEnabled {
return oauth2.StaticTokenSource(&t.currentToken), nil
}

cloudToken, err := newCloudTokenSource(t.defaultContext)
if err != nil {
return nil, err
}
if t.noWrap {
return cloudToken, nil
}

return oauth2.ReuseTokenSource(nil, cloudToken), nil
return oauth2.StaticTokenSource(&t.currentToken), nil
}

type tokenSource struct {
defaultContext context.Context
currentToken oauth2.Token
noWrap bool
}

type cloudTokenSource struct {
defaultContext context.Context
apiClient *cloudquery_api.ClientWithResponses

apiURL string
apiToken string
teamName string
syncName string
testConnUUID uuid.UUID
syncRunUUID uuid.UUID
connectorUUID uuid.UUID
isTestConnection bool
}

var _ oauth2.TokenSource = (*cloudTokenSource)(nil)

func newCloudTokenSource(defaultContext context.Context) (oauth2.TokenSource, error) {
t := &cloudTokenSource{
defaultContext: defaultContext,
}
if t.defaultContext == nil {
t.defaultContext = context.Background()
}

err := t.initCloudOpts()
if err != nil {
return nil, err
}

t.apiClient, err = cloudquery_api.NewClientWithResponses(t.apiURL,
cloudquery_api.WithRequestEditorFn(func(_ context.Context, req *http.Request) error {
req.Header.Set("Authorization", "Bearer "+t.apiToken)
return nil
}))
if err != nil {
return nil, fmt.Errorf("failed to create api client: %w", err)
}

return t, nil
}

// Token returns a new token from the remote source using the default context.
func (t *cloudTokenSource) Token() (*oauth2.Token, error) {
return t.retrieveToken(t.defaultContext)
}

func (t *cloudTokenSource) retrieveToken(ctx context.Context) (*oauth2.Token, error) {
var oauthResp *cloudquery_api.ConnectorCredentialsResponseOAuth
if !t.isTestConnection {
resp, err := t.apiClient.GetSyncRunConnectorCredentialsWithResponse(ctx, t.teamName, t.syncName, t.syncRunUUID, t.connectorUUID)
if err != nil {
return nil, fmt.Errorf("failed to get sync run connector credentials: %w", err)
}
if resp.StatusCode() != http.StatusOK {
if resp.JSON422 != nil {
return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.JSON422.Message)
}
return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.Status())
}
oauthResp = resp.JSON200.Oauth
} else {
resp, err := t.apiClient.GetTestConnectionConnectorCredentialsWithResponse(ctx, t.teamName, t.testConnUUID, t.connectorUUID)
if err != nil {
return nil, fmt.Errorf("failed to get test connection connector credentials: %w", err)
}
if resp.StatusCode() != http.StatusOK {
if resp.JSON422 != nil {
return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.JSON422.Message)
}
return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.Status())
}
oauthResp = resp.JSON200.Oauth
}

if oauthResp == nil {
return nil, errors.New("missing oauth credentials in response")
}

tok := &oauth2.Token{
AccessToken: oauthResp.AccessToken,
}
if oauthResp.Expires != nil {
tok.Expiry = *oauthResp.Expires
}
return tok, nil
}

func (t *cloudTokenSource) initCloudOpts() error {
var allErr error

t.apiToken = os.Getenv("CLOUDQUERY_API_KEY")
if t.apiToken == "" {
allErr = errors.Join(allErr, errors.New("CLOUDQUERY_API_KEY missing"))
}
t.apiURL = os.Getenv("CLOUDQUERY_API_URL")
if t.apiURL == "" {
t.apiURL = "https://api.cloudquery.io"
}

t.teamName = os.Getenv("_CQ_TEAM_NAME")
if t.teamName == "" {
allErr = errors.Join(allErr, errors.New("_CQ_TEAM_NAME missing"))
}
t.syncName = os.Getenv("_CQ_SYNC_NAME")
syncRunID := os.Getenv("_CQ_SYNC_RUN_ID")
testConnID := os.Getenv("_CQ_SYNC_TEST_CONNECTION_ID")
if testConnID == "" && syncRunID == "" {
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID or _CQ_SYNC_RUN_ID missing"))
} else if testConnID != "" && syncRunID != "" {
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID and _CQ_SYNC_RUN_ID are mutually exclusive"))
}

var err error
if syncRunID != "" {
if t.syncName == "" {
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME missing"))
}

t.syncRunUUID, err = uuid.Parse(syncRunID)
if err != nil {
allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_RUN_ID is not a valid UUID: %w", err))
}
}
if testConnID != "" {
if t.syncName != "" {
allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME should be empty"))
}

t.testConnUUID, err = uuid.Parse(testConnID)
if err != nil {
allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_TEST_CONNECTION_ID is not a valid UUID: %w", err))
}
t.isTestConnection = true
}

connectorID := os.Getenv("_CQ_CONNECTOR_ID")
if connectorID == "" {
allErr = errors.Join(allErr, errors.New("_CQ_CONNECTOR_ID missing"))
} else {
t.connectorUUID, err = uuid.Parse(connectorID)
if err != nil {
allErr = errors.Join(allErr, fmt.Errorf("_CQ_CONNECTOR_ID is not a valid UUID: %w", err))
}
}
return allErr
currentToken oauth2.Token
}
128 changes: 0 additions & 128 deletions helpers/remoteoauth/token_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
package remoteoauth

import (
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

const testAPIKey = "test-key"

func TestLocalTokenAccess(t *testing.T) {
r := require.New(t)
_, cloud := os.LookupEnv("CQ_CLOUD")
Expand All @@ -37,126 +32,3 @@ func TestLocalTokenAccessWithDeprecatedTokenOpt(t *testing.T) {
r.True(tk.Valid())
r.Equal("token", tk.AccessToken)
}

func TestFirstLocalTokenAccess(t *testing.T) {
runID := uuid.NewString()
connID := uuid.NewString()
testURL := setupMockTokenServer(t, map[string]string{
"/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`,
})
setEnvs(t, map[string]string{
"CQ_CLOUD": "1",
"CLOUDQUERY_API_URL": testURL,
"CLOUDQUERY_API_KEY": testAPIKey,
"_CQ_TEAM_NAME": "the-team",
"_CQ_SYNC_NAME": "the-sync",
"_CQ_SYNC_RUN_ID": runID,
"_CQ_CONNECTOR_ID": connID,
})
r := require.New(t)
tok, err := NewTokenSource(WithToken(oauth2.Token{AccessToken: "token"}))
r.NoError(err)
tk, err := tok.Token()
r.NoError(err)
r.True(tk.Valid())
r.Equal("new-token", tk.AccessToken)
}

func TestInvalidAPIKeyTokenAccess(t *testing.T) {
runID := uuid.NewString()
connID := uuid.NewString()
testURL := setupMockTokenServer(t, nil)
setEnvs(t, map[string]string{
"CQ_CLOUD": "1",
"CLOUDQUERY_API_URL": testURL,
"CLOUDQUERY_API_KEY": "invalid",
"_CQ_TEAM_NAME": "the-team",
"_CQ_SYNC_NAME": "the-sync",
"_CQ_SYNC_RUN_ID": runID,
"_CQ_CONNECTOR_ID": connID,
})
r := require.New(t)
tok, err := NewTokenSource(WithToken(oauth2.Token{AccessToken: "token"}))
r.NoError(err)
tk, err := tok.Token()
r.Nil(tk)
r.False(tk.Valid())
r.ErrorContains(err, "failed to get sync run connector credentials")
}

func TestSyncRunTokenAccess(t *testing.T) {
runID := uuid.NewString()
connID := uuid.NewString()
testURL := setupMockTokenServer(t, map[string]string{
"/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`,
})
setEnvs(t, map[string]string{
"CQ_CLOUD": "1",
"CLOUDQUERY_API_URL": testURL,
"CLOUDQUERY_API_KEY": testAPIKey,
"_CQ_TEAM_NAME": "the-team",
"_CQ_SYNC_NAME": "the-sync",
"_CQ_SYNC_RUN_ID": runID,
"_CQ_CONNECTOR_ID": connID,
})
r := require.New(t)
tok, err := NewTokenSource()
r.NoError(err)
tk, err := tok.Token()
r.NoError(err)
r.True(tk.Valid())
r.Equal("new-token", tk.AccessToken)
}

func TestTestConnectionTokenAccess(t *testing.T) {
testID := uuid.NewString()
connID := uuid.NewString()
testURL := setupMockTokenServer(t, map[string]string{
"/teams/the-team/syncs/test-connections/" + testID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`,
})
setEnvs(t, map[string]string{
"CQ_CLOUD": "1",
"CLOUDQUERY_API_URL": testURL,
"CLOUDQUERY_API_KEY": testAPIKey,
"_CQ_TEAM_NAME": "the-team",
"_CQ_SYNC_TEST_CONNECTION_ID": testID,
"_CQ_CONNECTOR_ID": connID,
})
r := require.New(t)
tok, err := NewTokenSource(WithToken(oauth2.Token{AccessToken: "token"}))
r.NoError(err)
tk, err := tok.Token()
r.NoError(err)
r.True(tk.Valid())
r.Equal("new-token", tk.AccessToken)
}

func setEnvs(t *testing.T, envs map[string]string) {
t.Helper()
for k, v := range envs {
t.Setenv(k, v)
}
}

func setupMockTokenServer(t *testing.T, responses map[string]string) string {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if a := r.Header.Get("Authorization"); a != "Bearer "+testAPIKey {
w.WriteHeader(http.StatusUnauthorized)
return
}

resp, ok := responses[r.URL.Path]
if !ok {
w.WriteHeader(http.StatusNotFound)
return
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(resp))
}))
t.Cleanup(func() {
ts.Close()
})
return ts.URL
}
14 changes: 4 additions & 10 deletions helpers/remoteoauth/tokenoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,15 @@ func WithAccessToken(token, tokenType string, expiry time.Time) TokenSourceOptio
}

// WithToken sets the default token for the token source.
// Deprecated: Use oauth2.StaticTokenSource directly instead.
func WithToken(token oauth2.Token) TokenSourceOption {
return func(t *tokenSource) {
t.currentToken = token
}
}

// WithDefaultContext sets the default context for the token source, used when creating a new token request.
func WithDefaultContext(ctx context.Context) TokenSourceOption {
return func(t *tokenSource) {
t.defaultContext = ctx
}
}

func withNoWrap() TokenSourceOption {
return func(t *tokenSource) {
t.noWrap = true
}
// Deprecated: not used in the current implementation.
func WithDefaultContext(_ context.Context) TokenSourceOption {
return func(*tokenSource) {}
}
Loading