diff --git a/helpers/remoteoauth/token.go b/helpers/remoteoauth/token.go index 667fef4abd..02e67a8f62 100644 --- a/helpers/remoteoauth/token.go +++ b/helpers/remoteoauth/token.go @@ -1,187 +1,19 @@ package remoteoauth import ( - "context" - "errors" - "fmt" - "net/http" - "os" - - cloudquery_api "github.com/cloudquery/cloudquery-api-go" - "github.com/google/uuid" "golang.org/x/oauth2" ) +// NewTokenSource creates a new token source. +// Deprecated: Use oauth2.StaticTokenSource directly instead. func NewTokenSource(opts ...TokenSourceOption) (oauth2.TokenSource, error) { t := &tokenSource{} for _, opt := range opts { opt(t) } - - if _, cloudEnabled := os.LookupEnv("CQ_CLOUD"); !cloudEnabled { - return oauth2.StaticTokenSource(&t.currentToken), nil - } - - cloudToken, err := newCloudTokenSource(t.defaultContext) - if err != nil { - return nil, err - } - if t.noWrap { - return cloudToken, nil - } - - return oauth2.ReuseTokenSource(nil, cloudToken), nil + return oauth2.StaticTokenSource(&t.currentToken), nil } type tokenSource struct { - defaultContext context.Context - currentToken oauth2.Token - noWrap bool -} - -type cloudTokenSource struct { - defaultContext context.Context - apiClient *cloudquery_api.ClientWithResponses - - apiURL string - apiToken string - teamName string - syncName string - testConnUUID uuid.UUID - syncRunUUID uuid.UUID - connectorUUID uuid.UUID - isTestConnection bool -} - -var _ oauth2.TokenSource = (*cloudTokenSource)(nil) - -func newCloudTokenSource(defaultContext context.Context) (oauth2.TokenSource, error) { - t := &cloudTokenSource{ - defaultContext: defaultContext, - } - if t.defaultContext == nil { - t.defaultContext = context.Background() - } - - err := t.initCloudOpts() - if err != nil { - return nil, err - } - - t.apiClient, err = cloudquery_api.NewClientWithResponses(t.apiURL, - cloudquery_api.WithRequestEditorFn(func(_ context.Context, req *http.Request) error { - req.Header.Set("Authorization", "Bearer "+t.apiToken) - return nil - })) - if err != nil { - return nil, fmt.Errorf("failed to create api client: %w", err) - } - - return t, nil -} - -// Token returns a new token from the remote source using the default context. -func (t *cloudTokenSource) Token() (*oauth2.Token, error) { - return t.retrieveToken(t.defaultContext) -} - -func (t *cloudTokenSource) retrieveToken(ctx context.Context) (*oauth2.Token, error) { - var oauthResp *cloudquery_api.ConnectorCredentialsResponseOAuth - if !t.isTestConnection { - resp, err := t.apiClient.GetSyncRunConnectorCredentialsWithResponse(ctx, t.teamName, t.syncName, t.syncRunUUID, t.connectorUUID) - if err != nil { - return nil, fmt.Errorf("failed to get sync run connector credentials: %w", err) - } - if resp.StatusCode() != http.StatusOK { - if resp.JSON422 != nil { - return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.JSON422.Message) - } - return nil, fmt.Errorf("failed to get sync run connector credentials: %s", resp.Status()) - } - oauthResp = resp.JSON200.Oauth - } else { - resp, err := t.apiClient.GetTestConnectionConnectorCredentialsWithResponse(ctx, t.teamName, t.testConnUUID, t.connectorUUID) - if err != nil { - return nil, fmt.Errorf("failed to get test connection connector credentials: %w", err) - } - if resp.StatusCode() != http.StatusOK { - if resp.JSON422 != nil { - return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.JSON422.Message) - } - return nil, fmt.Errorf("failed to get test connection connector credentials: %s", resp.Status()) - } - oauthResp = resp.JSON200.Oauth - } - - if oauthResp == nil { - return nil, errors.New("missing oauth credentials in response") - } - - tok := &oauth2.Token{ - AccessToken: oauthResp.AccessToken, - } - if oauthResp.Expires != nil { - tok.Expiry = *oauthResp.Expires - } - return tok, nil -} - -func (t *cloudTokenSource) initCloudOpts() error { - var allErr error - - t.apiToken = os.Getenv("CLOUDQUERY_API_KEY") - if t.apiToken == "" { - allErr = errors.Join(allErr, errors.New("CLOUDQUERY_API_KEY missing")) - } - t.apiURL = os.Getenv("CLOUDQUERY_API_URL") - if t.apiURL == "" { - t.apiURL = "https://api.cloudquery.io" - } - - t.teamName = os.Getenv("_CQ_TEAM_NAME") - if t.teamName == "" { - allErr = errors.Join(allErr, errors.New("_CQ_TEAM_NAME missing")) - } - t.syncName = os.Getenv("_CQ_SYNC_NAME") - syncRunID := os.Getenv("_CQ_SYNC_RUN_ID") - testConnID := os.Getenv("_CQ_SYNC_TEST_CONNECTION_ID") - if testConnID == "" && syncRunID == "" { - allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID or _CQ_SYNC_RUN_ID missing")) - } else if testConnID != "" && syncRunID != "" { - allErr = errors.Join(allErr, errors.New("_CQ_SYNC_TEST_CONNECTION_ID and _CQ_SYNC_RUN_ID are mutually exclusive")) - } - - var err error - if syncRunID != "" { - if t.syncName == "" { - allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME missing")) - } - - t.syncRunUUID, err = uuid.Parse(syncRunID) - if err != nil { - allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_RUN_ID is not a valid UUID: %w", err)) - } - } - if testConnID != "" { - if t.syncName != "" { - allErr = errors.Join(allErr, errors.New("_CQ_SYNC_NAME should be empty")) - } - - t.testConnUUID, err = uuid.Parse(testConnID) - if err != nil { - allErr = errors.Join(allErr, fmt.Errorf("_CQ_SYNC_TEST_CONNECTION_ID is not a valid UUID: %w", err)) - } - t.isTestConnection = true - } - - connectorID := os.Getenv("_CQ_CONNECTOR_ID") - if connectorID == "" { - allErr = errors.Join(allErr, errors.New("_CQ_CONNECTOR_ID missing")) - } else { - t.connectorUUID, err = uuid.Parse(connectorID) - if err != nil { - allErr = errors.Join(allErr, fmt.Errorf("_CQ_CONNECTOR_ID is not a valid UUID: %w", err)) - } - } - return allErr + currentToken oauth2.Token } diff --git a/helpers/remoteoauth/token_test.go b/helpers/remoteoauth/token_test.go index 58af61d85c..f7402258fa 100644 --- a/helpers/remoteoauth/token_test.go +++ b/helpers/remoteoauth/token_test.go @@ -1,19 +1,14 @@ package remoteoauth import ( - "net/http" - "net/http/httptest" "os" "testing" "time" - "github.com/google/uuid" "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) -const testAPIKey = "test-key" - func TestLocalTokenAccess(t *testing.T) { r := require.New(t) _, cloud := os.LookupEnv("CQ_CLOUD") @@ -37,126 +32,3 @@ func TestLocalTokenAccessWithDeprecatedTokenOpt(t *testing.T) { r.True(tk.Valid()) r.Equal("token", tk.AccessToken) } - -func TestFirstLocalTokenAccess(t *testing.T) { - runID := uuid.NewString() - connID := uuid.NewString() - testURL := setupMockTokenServer(t, map[string]string{ - "/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`, - }) - setEnvs(t, map[string]string{ - "CQ_CLOUD": "1", - "CLOUDQUERY_API_URL": testURL, - "CLOUDQUERY_API_KEY": testAPIKey, - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_NAME": "the-sync", - "_CQ_SYNC_RUN_ID": runID, - "_CQ_CONNECTOR_ID": connID, - }) - r := require.New(t) - tok, err := NewTokenSource(WithToken(oauth2.Token{AccessToken: "token"})) - r.NoError(err) - tk, err := tok.Token() - r.NoError(err) - r.True(tk.Valid()) - r.Equal("new-token", tk.AccessToken) -} - -func TestInvalidAPIKeyTokenAccess(t *testing.T) { - runID := uuid.NewString() - connID := uuid.NewString() - testURL := setupMockTokenServer(t, nil) - setEnvs(t, map[string]string{ - "CQ_CLOUD": "1", - "CLOUDQUERY_API_URL": testURL, - "CLOUDQUERY_API_KEY": "invalid", - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_NAME": "the-sync", - "_CQ_SYNC_RUN_ID": runID, - "_CQ_CONNECTOR_ID": connID, - }) - r := require.New(t) - tok, err := NewTokenSource(WithToken(oauth2.Token{AccessToken: "token"})) - r.NoError(err) - tk, err := tok.Token() - r.Nil(tk) - r.False(tk.Valid()) - r.ErrorContains(err, "failed to get sync run connector credentials") -} - -func TestSyncRunTokenAccess(t *testing.T) { - runID := uuid.NewString() - connID := uuid.NewString() - testURL := setupMockTokenServer(t, map[string]string{ - "/teams/the-team/syncs/the-sync/runs/" + runID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`, - }) - setEnvs(t, map[string]string{ - "CQ_CLOUD": "1", - "CLOUDQUERY_API_URL": testURL, - "CLOUDQUERY_API_KEY": testAPIKey, - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_NAME": "the-sync", - "_CQ_SYNC_RUN_ID": runID, - "_CQ_CONNECTOR_ID": connID, - }) - r := require.New(t) - tok, err := NewTokenSource() - r.NoError(err) - tk, err := tok.Token() - r.NoError(err) - r.True(tk.Valid()) - r.Equal("new-token", tk.AccessToken) -} - -func TestTestConnectionTokenAccess(t *testing.T) { - testID := uuid.NewString() - connID := uuid.NewString() - testURL := setupMockTokenServer(t, map[string]string{ - "/teams/the-team/syncs/test-connections/" + testID + "/connector/" + connID + "/credentials": `{"oauth":{"access_token":"new-token"}}`, - }) - setEnvs(t, map[string]string{ - "CQ_CLOUD": "1", - "CLOUDQUERY_API_URL": testURL, - "CLOUDQUERY_API_KEY": testAPIKey, - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_TEST_CONNECTION_ID": testID, - "_CQ_CONNECTOR_ID": connID, - }) - r := require.New(t) - tok, err := NewTokenSource(WithToken(oauth2.Token{AccessToken: "token"})) - r.NoError(err) - tk, err := tok.Token() - r.NoError(err) - r.True(tk.Valid()) - r.Equal("new-token", tk.AccessToken) -} - -func setEnvs(t *testing.T, envs map[string]string) { - t.Helper() - for k, v := range envs { - t.Setenv(k, v) - } -} - -func setupMockTokenServer(t *testing.T, responses map[string]string) string { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if a := r.Header.Get("Authorization"); a != "Bearer "+testAPIKey { - w.WriteHeader(http.StatusUnauthorized) - return - } - - resp, ok := responses[r.URL.Path] - if !ok { - w.WriteHeader(http.StatusNotFound) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(resp)) - })) - t.Cleanup(func() { - ts.Close() - }) - return ts.URL -} diff --git a/helpers/remoteoauth/tokenoptions.go b/helpers/remoteoauth/tokenoptions.go index 95b3fa2f43..5595e87d4d 100644 --- a/helpers/remoteoauth/tokenoptions.go +++ b/helpers/remoteoauth/tokenoptions.go @@ -22,6 +22,7 @@ func WithAccessToken(token, tokenType string, expiry time.Time) TokenSourceOptio } // WithToken sets the default token for the token source. +// Deprecated: Use oauth2.StaticTokenSource directly instead. func WithToken(token oauth2.Token) TokenSourceOption { return func(t *tokenSource) { t.currentToken = token @@ -29,14 +30,7 @@ func WithToken(token oauth2.Token) TokenSourceOption { } // WithDefaultContext sets the default context for the token source, used when creating a new token request. -func WithDefaultContext(ctx context.Context) TokenSourceOption { - return func(t *tokenSource) { - t.defaultContext = ctx - } -} - -func withNoWrap() TokenSourceOption { - return func(t *tokenSource) { - t.noWrap = true - } +// Deprecated: not used in the current implementation. +func WithDefaultContext(_ context.Context) TokenSourceOption { + return func(*tokenSource) {} } diff --git a/helpers/remoteoauth/tokenoptions_test.go b/helpers/remoteoauth/tokenoptions_test.go deleted file mode 100644 index 7873e98a70..0000000000 --- a/helpers/remoteoauth/tokenoptions_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package remoteoauth - -import ( - "reflect" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/require" -) - -func TestInitCloudOpts(t *testing.T) { - validUUID := uuid.NewString() - - cases := []struct { - name string - envs map[string]string - expectError bool - expectCloud bool - expectTestConn bool - }{ - { - name: "no envs", - }, - { - name: "cloud env", - envs: map[string]string{ - "CQ_CLOUD": "1", - "CLOUDQUERY_API_KEY": "the-key", - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_NAME": "the-sync", - "_CQ_SYNC_RUN_ID": validUUID, - "_CQ_CONNECTOR_ID": validUUID, - }, - expectCloud: true, - }, - { - name: "cloud env test conn", - envs: map[string]string{ - "CQ_CLOUD": "1", - "CLOUDQUERY_API_KEY": "the-key", - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_TEST_CONNECTION_ID": validUUID, - "_CQ_CONNECTOR_ID": validUUID, - }, - expectCloud: true, - expectTestConn: true, - }, - { - name: "missing cq_cloud with everything set", - envs: map[string]string{ - "CLOUDQUERY_API_KEY": "the-key", - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_NAME": "the-sync", - "_CQ_SYNC_RUN_ID": validUUID, - "_CQ_CONNECTOR_ID": validUUID, - }, - expectCloud: false, - }, - { - name: "missing cq_cloud with missing api key", - envs: map[string]string{ - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_NAME": "the-sync", - "_CQ_SYNC_RUN_ID": validUUID, - "_CQ_CONNECTOR_ID": validUUID, - }, - expectCloud: false, - }, - { - name: "missing cq_cloud with missing sync name", - envs: map[string]string{ - "CLOUDQUERY_API_KEY": "the-key", - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_NAME": "the-sync", - "_CQ_SYNC_RUN_ID": validUUID, - "_CQ_CONNECTOR_ID": validUUID, - }, - expectCloud: false, - }, - { - name: "cloud env missing api key", - envs: map[string]string{ - "CQ_CLOUD": "1", - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_TEST_CONNECTION_ID": validUUID, - "_CQ_CONNECTOR_ID": validUUID, - }, - expectError: true, - }, - { - name: "cloud env missing sync name", - envs: map[string]string{ - "CQ_CLOUD": "1", - "CLOUDQUERY_API_KEY": "the-key", - "_CQ_TEAM_NAME": "the-team", - "_CQ_SYNC_RUN_ID": validUUID, - "_CQ_CONNECTOR_ID": validUUID, - }, - expectError: true, - }, - } - for _, tc := range cases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - r := require.New(t) - for k, v := range tc.envs { - t.Setenv(k, v) - } - tok, err := NewTokenSource(withNoWrap()) - if tc.expectError { - r.Error(err) - return - } - r.NoError(err) - if tc.expectCloud { - ts := tok.(*cloudTokenSource) - r.Equal(tc.expectTestConn, ts.isTestConnection) - return - } - rt := reflect.TypeOf(tok) - r.Equal("oauth2.staticTokenSource", rt.String()) - }) - } -}