Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 78 additions & 6 deletions go/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -266,10 +267,13 @@ func (c *connectionImpl) GetTablesForDBSchema(ctx context.Context, catalog strin
}

type bigQueryTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorURI string `json:"error_uri"`
}

// GetCurrentCatalog implements driverbase.CurrentNamespacer.
Expand Down Expand Up @@ -712,9 +716,36 @@ func (c *connectionImpl) newClient(ctx context.Context) error {
// First, establish base authentication
switch c.authType {
case OptionValueAuthTypeJSONCredentialFile:
authOptions = append(authOptions, option.WithAuthCredentialsFile(c.credentialsType, c.credentials))
credType := c.credentialsType
if credType == "" {
// Auto-detect credential type from the JSON file
detected, err := detectCredentialTypeFromFile(c.credentials)
if err != nil {
return adbc.Error{
Code: adbc.StatusInvalidArgument,
Msg: fmt.Sprintf("[bq] failed to detect credential type from file %q: %s", c.credentials, err.Error()),
}
}
credType = detected
}
c.credentialsType = credType
c.Logger.Debug("Using JSON credential file", "file", c.credentials, "credentialType", string(credType))
authOptions = append(authOptions, option.WithAuthCredentialsFile(credType, c.credentials))
case OptionValueAuthTypeJSONCredentialString:
authOptions = append(authOptions, option.WithAuthCredentialsJSON(c.credentialsType, []byte(c.credentials)))
credType := c.credentialsType
if credType == "" {
detected, err := detectCredentialTypeFromJSON([]byte(c.credentials))
if err != nil {
return adbc.Error{
Code: adbc.StatusInvalidArgument,
Msg: fmt.Sprintf("[bq] failed to detect credential type from JSON: %s", err.Error()),
}
}
credType = detected
}
c.credentialsType = credType
c.Logger.Debug("Using JSON credential string", "credentialType", string(credType))
authOptions = append(authOptions, option.WithAuthCredentialsJSON(credType, []byte(c.credentials)))
case OptionValueAuthTypeUserAuthentication:
if c.clientID == "" {
return adbc.Error{
Expand All @@ -734,8 +765,10 @@ func (c *connectionImpl) newClient(ctx context.Context) error {
Msg: fmt.Sprintf("[bq] `%s` parameter is empty", OptionStringAuthRefreshToken),
}
}
c.Logger.Debug("Using user OAuth authentication")
authOptions = append(authOptions, option.WithTokenSource(c))
case OptionValueAuthTypeAppDefaultCredentials, OptionValueAuthTypeDefault, "":
c.Logger.Debug("Using Application Default Credentials (ADC)", "authType", c.authType)
// Use Application Default Credentials (default behavior)
// No additional options needed - ADC is used by default
default:
Expand Down Expand Up @@ -1146,5 +1179,44 @@ func (c *connectionImpl) getAccessToken() (*bigQueryTokenResponse, error) {
if err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err, "get access token")
}

if tokenResponse.Error != "" {
msg := fmt.Sprintf("[bq] OAuth token error: %s", tokenResponse.Error)
if tokenResponse.ErrorDescription != "" {
msg += ": " + tokenResponse.ErrorDescription
}
if isReauthError(tokenResponse.ErrorDescription) {
msg += ". " + reauthGuidance
}
if tokenResponse.ErrorURI != "" {
msg += " (see: " + tokenResponse.ErrorURI + ")"
}
return nil, adbc.Error{
Code: adbc.StatusUnauthorized,
Msg: msg,
}
}

return &tokenResponse, nil
}

func detectCredentialTypeFromFile(filename string) (option.CredentialsType, error) {
data, err := os.ReadFile(filename)
if err != nil {
return "", fmt.Errorf("read credential file: %w", err)
}
return detectCredentialTypeFromJSON(data)
}

func detectCredentialTypeFromJSON(data []byte) (option.CredentialsType, error) {
var f struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &f); err != nil {
return "", fmt.Errorf("parse credential JSON: %w", err)
}
if f.Type == "" {
return "", fmt.Errorf("missing 'type' field in credential JSON")
}
return option.CredentialsType(f.Type), nil
}
15 changes: 15 additions & 0 deletions go/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,24 @@ func errToAdbcErr(defaultStatus adbc.Status, err error, errContext string, conte
}

adbcErr.Msg = msg.String()

if isReauthError(err.Error()) {
adbcErr.Code = adbc.StatusUnauthorized
adbcErr.Msg += ". " + reauthGuidance
}

return adbcErr
}

const reauthGuidance = "Your Google Workspace admin requires re-authentication (RAPT). " +
"Consider using a service account instead of user credentials, or re-authenticate " +
"interactively with 'gcloud auth application-default login'. " +
"See https://support.google.com/a/answer/9368756"

func isReauthError(s string) bool {
return strings.Contains(s, "invalid_rapt") || strings.Contains(s, "reauth related error")
}

func retryWithBackoff(ctx context.Context, context string, maxAttempts int, backoff gax.Backoff, f func() (bool, error)) error {
attempt := 0
for {
Expand Down
Loading