diff --git a/.github/workflows/capabilities_and_config.yaml b/.github/workflows/capabilities_and_config.yaml index c7ea94e4..a410d6fb 100644 --- a/.github/workflows/capabilities_and_config.yaml +++ b/.github/workflows/capabilities_and_config.yaml @@ -27,15 +27,18 @@ jobs: - name: Run and save config output run: ./connector config > config_schema.json + - name: Setup private key + run: | + echo "${{ secrets.BATON_PRIVATE_KEY }}" | base64 -d > /tmp/snowflake_key.p8 + chmod 600 /tmp/snowflake_key.p8 + - name: Run and save capabilities output env: BATON_ACCOUNT_IDENTIFIER: example BATON_USER_IDENTIFIER: example BATON_ACCOUNT_URL: https://example.snowflakecomputing.com - BATON_PRIVATE_KEY_PATH: ${{ runner.temp }}/baton_private_key.pem - run: | - openssl genrsa -out "$BATON_PRIVATE_KEY_PATH" 2048 - ./connector --sync-secrets capabilities > baton_capabilities.json + BATON_PRIVATE_KEY_PATH: /tmp/snowflake_key.p8 + run: ./connector --sync-secrets capabilities > baton_capabilities.json - name: Commit changes uses: EndBug/add-and-commit@v9 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dc035996..61539aec 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -43,19 +43,13 @@ jobs: test: runs-on: ubuntu-latest - # Define any services needed for the test suite (or delete this section) - # services: - # postgres: - # image: postgres:16 - # ports: - # - "5432:5432" - # env: - # POSTGRES_PASSWORD: secretpassword + if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main' env: BATON_LOG_LEVEL: debug - # Add any environment variables needed to run baton-snowflake - # BATON_BASE_URL: 'http://localhost:8080' - # BATON_ACCESS_TOKEN: 'secret_token' + # Snowflake connection configuration + BATON_ACCOUNT_IDENTIFIER: ${{ secrets.BATON_ACCOUNT_IDENTIFIER }} + BATON_USER_IDENTIFIER: ${{ secrets.BATON_USER_IDENTIFIER }} + BATON_ACCOUNT_URL: ${{ secrets.BATON_ACCOUNT_URL }} # The following parameters are passed to grant/revoke commands # Change these to the correct IDs for your test data CONNECTOR_GRANT: 'grant:entitlement:group:1234:member:user:9876' @@ -69,37 +63,29 @@ jobs: uses: actions/setup-go@v5 with: go-version-file: 'go.mod' - # Install any dependencies here (or delete this) - # - name: Install postgres client - # run: sudo apt install postgresql-client - # Run any fixture setup here (or delete this) - # - name: Import sql into postgres - # run: psql -h localhost --user postgres -f environment.sql - # env: - # PGPASSWORD: secretpassword + - name: Build baton-snowflake - run: go build ./cmd/baton-snowflake + run: go build -o baton-snowflake ./cmd/baton-snowflake # - name: Run baton-snowflake # run: ./baton-snowflake - name: Install baton run: ./scripts/get-baton.sh && mv baton /usr/local/bin - # - name: Check for grant before revoking - # run: - # baton grants --entitlement="${{ env.CONNECTOR_ENTITLEMENT }}" --output-format=json | jq --exit-status ".grants[].principal.id.resource == \"${{ env.CONNECTOR_PRINCIPAL }}\"" - - # - name: Revoke grants - # run: ./baton-snowflake --revoke-grant="${{ env.CONNECTOR_GRANT }}" - - # - name: Check grant was revoked - # run: ./baton-snowflake && baton grants --entitlement="${{ env.CONNECTOR_ENTITLEMENT }}" --output-format=json | jq --exit-status "if .grants then .grants[]?.principal.id.resource != \"${{ env.CONNECTOR_PRINCIPAL }}\" else . end" + - name: Setup private key + run: | + echo "${{ secrets.BATON_PRIVATE_KEY }}" | base64 -d > /tmp/snowflake_key.p8 + chmod 600 /tmp/snowflake_key.p8 - # - name: Grant entitlement - # # Change the grant arguments to the correct IDs for your test data - # run: ./baton-snowflake --grant-entitlement="${{ env.CONNECTOR_ENTITLEMENT }}" --grant-principal="${{ env.CONNECTOR_PRINCIPAL }}" --grant-principal-type="${{ env.CONNECTOR_PRINCIPAL_TYPE }}" - - # - name: Check grant was re-granted - # run: - # baton grants --entitlement="${{ env.CONNECTOR_ENTITLEMENT }}" --output-format=json | jq --exit-status ".grants[].principal.id.resource == \"${{ env.CONNECTOR_PRINCIPAL }}\"" + - name: Test Account Provisioning + uses: ConductorOne/github-workflows/actions/account-provisioning@v4 + with: + connector: './baton-snowflake' + account-email: 'testProvisioningUser@example.com' + account-login: 'testProvisioningUser' + account-profile: '{"first_name": "Test", "last_name": "User", "name": "testProvisioningUser", "email": "testProvisioningUser@example.com"}' + account-type: 'user' + search-method: 'email' + env: + BATON_PRIVATE_KEY_PATH: /tmp/snowflake_key.p8 diff --git a/pkg/connector/account_roles.go b/pkg/connector/account_roles.go index a5b6c264..61ce2865 100644 --- a/pkg/connector/account_roles.go +++ b/pkg/connector/account_roles.go @@ -46,7 +46,7 @@ func (o *accountRoleBuilder) List(ctx context.Context, parentResourceID *v2.Reso return nil, nil, wrapError(err, "failed to get next page offset") } - accountRoles, _, err := o.client.ListAccountRoles(ctx, cursor, resourcePageSize) + accountRoles, err := o.client.ListAccountRoles(ctx, cursor, resourcePageSize) if err != nil { return nil, nil, wrapError(err, "failed to list account roles") } @@ -95,7 +95,7 @@ func (o *accountRoleBuilder) Entitlements(_ context.Context, resource *v2.Resour } func (o *accountRoleBuilder) Grants(ctx context.Context, resource *v2.Resource, _ rs.SyncOpAttrs) ([]*v2.Grant, *rs.SyncOpResults, error) { - accountRoleGrantees, _, err := o.client.ListAccountRoleGrantees(ctx, resource.DisplayName) + accountRoleGrantees, err := o.client.ListAccountRoleGrantees(ctx, resource.DisplayName) if err != nil { return nil, nil, wrapError(err, "failed to list account role grantees") } @@ -138,7 +138,7 @@ func (o *accountRoleBuilder) Grant(ctx context.Context, principal *v2.Resource, return nil, err } - _, err := o.client.GrantAccountRole(ctx, entitlement.Resource.Id.Resource, principal.Id.Resource) + err := o.client.GrantAccountRole(ctx, entitlement.Resource.Id.Resource, principal.Id.Resource) if err != nil { err = wrapError(err, "failed to grant account role") @@ -167,7 +167,7 @@ func (o *accountRoleBuilder) Revoke(ctx context.Context, grant *v2.Grant) (annot return nil, err } - _, err := o.client.RevokeAccountRole(ctx, grant.Entitlement.Resource.Id.Resource, grant.Principal.Id.Resource) + err := o.client.RevokeAccountRole(ctx, grant.Entitlement.Resource.Id.Resource, grant.Principal.Id.Resource) if err != nil { err = wrapError(err, "failed to revoke account role") diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index b95c7c9b..4d23922d 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -49,13 +49,136 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) return &v2.ConnectorMetadata{ DisplayName: "Baton Snowflake", Description: "Connector syncing users, databases, tables, and account roles from Snowflake.", + AccountCreationSchema: &v2.ConnectorAccountCreationSchema{ + FieldMap: map[string]*v2.ConnectorAccountCreationSchema_Field{ + "name": { + DisplayName: "User Name", + Required: true, + Description: "The name of the user (required - case-sensitive)", + Placeholder: "username", + Order: 0, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "login": { + DisplayName: "Login Name", + Required: false, + Description: "The login name for the user (defaults to email if not provided)", + Placeholder: "user@example.com", + Order: 1, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "display_name": { + DisplayName: "Display Name", + Required: false, + Description: "The display name for the user", + Placeholder: "John Doe", + Order: 2, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "first_name": { + DisplayName: "First Name", + Required: false, + Description: "The first name of the user", + Placeholder: "John", + Order: 3, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "last_name": { + DisplayName: "Last Name", + Required: false, + Description: "The last name of the user", + Placeholder: "Doe", + Order: 4, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "email": { + DisplayName: "Email", + Required: false, + Description: "The email address for the user", + Placeholder: "user@example.com", + Order: 5, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "comment": { + DisplayName: "Comment", + Required: false, + Description: "A comment or description for the user", + Placeholder: "User description", + Order: 6, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "disabled": { + DisplayName: "Disabled", + Required: false, + Description: "Whether the user account should be disabled", + Order: 7, + Field: &v2.ConnectorAccountCreationSchema_Field_BoolField{ + BoolField: &v2.ConnectorAccountCreationSchema_BoolField{}, + }, + }, + "default_warehouse": { + DisplayName: "Default Warehouse", + Required: false, + Description: "The default warehouse to use when this user starts a session", + Placeholder: "COMPUTE_WH", + Order: 8, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "default_namespace": { + DisplayName: "Default Namespace", + Required: false, + Description: "The default namespace to use when this user starts a session", + Placeholder: "DATABASE.SCHEMA", + Order: 9, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "default_role": { + DisplayName: "Default Role", + Required: false, + Description: "The default role to use when this user starts a session", + Placeholder: "PUBLIC", + Order: 10, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "default_secondary_roles": { + DisplayName: "Default Secondary Roles", + Required: false, + Description: "The default secondary roles of this user to use when starting a session. Valid values: ALL or NONE. Default is ALL.", + Placeholder: "ALL", + Order: 11, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + }, + }, }, nil } // Validate is called to ensure that the connector is properly configured. It should exercise any API credentials // to be sure that they are valid. func (d *Connector) Validate(ctx context.Context) (annotations.Annotations, error) { - users, _, err := d.Client.ListUsers(ctx, "", 1) + users, err := d.Client.ListUsers(ctx, "", 1) if err != nil { return nil, err } diff --git a/pkg/connector/databases.go b/pkg/connector/databases.go index b7f60920..ec5d2d6f 100644 --- a/pkg/connector/databases.go +++ b/pkg/connector/databases.go @@ -61,7 +61,7 @@ func (o *databaseBuilder) List(ctx context.Context, parentResourceID *v2.Resourc return nil, nil, wrapError(err, "failed to get next page offset") } - databases, _, err := o.client.ListDatabases(ctx, cursor, resourcePageSize) + databases, err := o.client.ListDatabases(ctx, cursor, resourcePageSize) if err != nil { return nil, nil, wrapError(err, "failed to list databases") } @@ -111,9 +111,9 @@ func (o *databaseBuilder) Grants(ctx context.Context, resource *v2.Resource, _ r return nil, nil, nil } - owner, ownerResp, err := o.client.GetAccountRole(ctx, database.Owner) + owner, ownerStatusCode, err := o.client.GetAccountRole(ctx, database.Owner) if err != nil { - if snowflake.IsUnprocessableEntity(ownerResp, err) { + if snowflake.IsUnprocessableEntity(ownerStatusCode, err) { wrappedErr := fmt.Errorf("baton-snowflake: insufficient privileges for database owner role %q (database %q): %w", database.Owner, resource.Id.Resource, err) return nil, nil, status.Error(codes.PermissionDenied, wrappedErr.Error()) } diff --git a/pkg/connector/helpers.go b/pkg/connector/helpers.go index 77f2b4d2..52f5fa50 100644 --- a/pkg/connector/helpers.go +++ b/pkg/connector/helpers.go @@ -1,9 +1,22 @@ package connector -import "fmt" +import ( + "fmt" + "strings" +) func wrapError(err error, message string) error { return fmt.Errorf("snowflake-connector: %s: %w", message, err) } const resourcePageSize = 50 + +// quoteSnowflakeIdentifier properly escapes and quotes a Snowflake identifier. +// In Snowflake, double quotes inside identifiers must be escaped by doubling them. +// Example: o"donnel becomes "o""donnel". +func quoteSnowflakeIdentifier(identifier string) string { + // Escape double quotes by doubling them + escaped := strings.ReplaceAll(identifier, `"`, `""`) + // Wrap in double quotes + return fmt.Sprintf(`"%s"`, escaped) +} diff --git a/pkg/connector/tables.go b/pkg/connector/tables.go index 95333208..e1b39e5d 100644 --- a/pkg/connector/tables.go +++ b/pkg/connector/tables.go @@ -69,8 +69,8 @@ func (o *tableBuilder) isDBSharedOrSystem(ctx context.Context, resource *v2.Reso return val == "true" || val == "1", nil } } - db, resp, err := o.client.GetDatabase(ctx, databaseName) - if snowflake.IsUnprocessableEntity(resp, err) { + db, statusCode, err := o.client.GetDatabase(ctx, databaseName) + if snowflake.IsUnprocessableEntity(statusCode, err) { return true, nil } if err != nil { @@ -148,7 +148,7 @@ func (o *tableBuilder) List(ctx context.Context, parentResourceID *v2.ResourceId } const accountPageSize = 200 - tables, nextCursor, _, err := o.client.ListTablesInAccount(ctx, cursor, accountPageSize) + tables, nextCursor, err := o.client.ListTablesInAccount(ctx, cursor, accountPageSize) if err != nil { return nil, nil, wrapError(err, "failed to list tables in account") } @@ -280,9 +280,9 @@ func (o *tableBuilder) Grants(ctx context.Context, resource *v2.Resource, opts r switch tg.GrantedTo { case grantedToRole: - role, resp, err := o.client.GetAccountRole(ctx, tg.GranteeName) + role, statusCode, err := o.client.GetAccountRole(ctx, tg.GranteeName) if err != nil { - if snowflake.IsUnprocessableEntity(resp, err) { + if snowflake.IsUnprocessableEntity(statusCode, err) { principalId, idErr := rs.NewResourceID(accountRoleResourceType, tg.GranteeName) if idErr != nil { continue @@ -335,14 +335,14 @@ func (o *tableBuilder) Grants(ctx context.Context, resource *v2.Resource, opts r } if ownerPrincipalID == nil { - table, _, err := o.client.GetTable(ctx, databaseName, schemaName, tableName) + table, err := o.client.GetTable(ctx, databaseName, schemaName, tableName) if err != nil { return nil, nil, wrapError(err, "failed to get table for owner fallback") } if table != nil && table.Owner != "" && table.Owner != "SNOWFLAKE" { - owner, ownerResp, err := o.client.GetAccountRole(ctx, table.Owner) + owner, ownerStatusCode, err := o.client.GetAccountRole(ctx, table.Owner) switch { - case snowflake.IsUnprocessableEntity(ownerResp, err): + case snowflake.IsUnprocessableEntity(ownerStatusCode, err): // system role, skip case err != nil: return nil, nil, wrapError(err, fmt.Sprintf("failed to get account role for table owner %q", table.Owner)) diff --git a/pkg/connector/users.go b/pkg/connector/users.go index b546ce1a..efadd082 100644 --- a/pkg/connector/users.go +++ b/pkg/connector/users.go @@ -2,10 +2,22 @@ package connector import ( "context" + "errors" + "fmt" + "net/http" + "strings" + "time" v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" + "github.com/conductorone/baton-sdk/pkg/annotations" + connectorbuilder "github.com/conductorone/baton-sdk/pkg/connectorbuilder" + "github.com/conductorone/baton-sdk/pkg/crypto" rs "github.com/conductorone/baton-sdk/pkg/types/resource" "github.com/conductorone/baton-snowflake/pkg/snowflake" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type userBuilder struct { @@ -99,6 +111,52 @@ func getUserDetailedStatus(user *snowflake.User) string { return "" } +// extractProfileFields extracts optional fields from the accountInfo profile and populates the createReq. +func extractProfileFields(accountInfo *v2.AccountInfo, createReq *snowflake.CreateUserRequest) { + profile := accountInfo.GetProfile() + if profile == nil { + return + } + + pMap := profile.AsMap() + + if loginNameStr, ok := pMap["login"].(string); ok && loginNameStr != "" { + createReq.LoginName = loginNameStr + } + if displayNameStr, ok := pMap["display_name"].(string); ok && displayNameStr != "" { + createReq.DisplayName = displayNameStr + } + if firstNameStr, ok := pMap["first_name"].(string); ok && firstNameStr != "" { + createReq.FirstName = firstNameStr + } + if lastNameStr, ok := pMap["last_name"].(string); ok && lastNameStr != "" { + createReq.LastName = lastNameStr + } + if emailStr, ok := pMap["email"].(string); ok && emailStr != "" { + createReq.Email = emailStr + } + if commentStr, ok := pMap["comment"].(string); ok && commentStr != "" { + createReq.Comment = commentStr + } + // Handle disabled as boolean + if disabledVal, ok := pMap["disabled"].(bool); ok { + createReq.Disabled = disabledVal + } + // Default warehouse, namespace, role, and secondary roles + if defaultWarehouseStr, ok := pMap["default_warehouse"].(string); ok && defaultWarehouseStr != "" { + createReq.DefaultWarehouse = defaultWarehouseStr + } + if defaultNamespaceStr, ok := pMap["default_namespace"].(string); ok && defaultNamespaceStr != "" { + createReq.DefaultNamespace = defaultNamespaceStr + } + if defaultRoleStr, ok := pMap["default_role"].(string); ok && defaultRoleStr != "" { + createReq.DefaultRole = defaultRoleStr + } + if defaultSecondaryRolesStr, ok := pMap["default_secondary_roles"].(string); ok && defaultSecondaryRolesStr != "" { + createReq.DefaultSecondaryRoles = defaultSecondaryRolesStr + } +} + // List returns all the users from the database as resource objects. // Users include a UserTrait because they are the 'shape' of a standard user. func (o *userBuilder) List(ctx context.Context, parentResourceID *v2.ResourceId, opts rs.SyncOpAttrs) ([]*v2.Resource, *rs.SyncOpResults, error) { @@ -107,7 +165,7 @@ func (o *userBuilder) List(ctx context.Context, parentResourceID *v2.ResourceId, return nil, nil, wrapError(err, "failed to get next page cursor") } - users, _, err := o.client.ListUsers(ctx, cursor, resourcePageSize) + users, err := o.client.ListUsers(ctx, cursor, resourcePageSize) if err != nil { return nil, nil, wrapError(err, "failed to list users") } @@ -148,6 +206,216 @@ func (o *userBuilder) Grants(ctx context.Context, resource *v2.Resource, _ rs.Sy return nil, nil, nil } +// CreateAccountCapabilityDetails returns the capability details for user account provisioning. +func (o *userBuilder) CreateAccountCapabilityDetails(ctx context.Context) (*v2.CredentialDetailsAccountProvisioning, annotations.Annotations, error) { + return &v2.CredentialDetailsAccountProvisioning{ + SupportedCredentialOptions: []v2.CapabilityDetailCredentialOption{ + v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_RANDOM_PASSWORD, + v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_ENCRYPTED_PASSWORD, + }, + PreferredCredentialOption: v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_RANDOM_PASSWORD, + }, nil, nil +} + +// CreateAccount creates a new Snowflake user using the REST API. +func (o *userBuilder) CreateAccount( + ctx context.Context, + accountInfo *v2.AccountInfo, + credentialOptions *v2.LocalCredentialOptions, +) (connectorbuilder.CreateAccountResponse, []*v2.PlaintextData, annotations.Annotations, error) { + l := ctxzap.Extract(ctx) + + // Extract user name from accountInfo + // The user name must be provided in profile.name (required in schema) + userName := "" + if profile := accountInfo.GetProfile(); profile != nil { + if nameStr, ok := rs.GetProfileStringValue(profile, "name"); ok && nameStr != "" { + userName = nameStr + } + } + + if userName == "" { + return nil, nil, nil, status.Error(codes.InvalidArgument, "baton-snowflake: user name is required (provide via profile.name)") + } + + // Build create user request + // name is the only required field for the create user request + // Quote the username to preserve case sensitivity (Snowflake stores unquoted identifiers in uppercase) + // Escape any double quotes in the username by doubling them + quotedUserName := quoteSnowflakeIdentifier(userName) + createReq := &snowflake.CreateUserRequest{ + Name: quotedUserName, + } + + // Extract optional fields from profile (login and email are optional - only set if provided in profile) + extractProfileFields(accountInfo, createReq) + + // Handle password generation + var plaintextData []*v2.PlaintextData + if credentialOptions != nil { + createReq.MustChangePassword = credentialOptions.GetForceChangeAtNextLogin() + // Generate password if random password is requested + if credentialOptions.GetRandomPassword() == nil && credentialOptions.GetPlaintextPassword() == nil { + return nil, nil, nil, errors.New("unsupported credential option") + } + plaintextPassword, err := crypto.GeneratePassword(ctx, credentialOptions) + if err != nil { + return nil, nil, nil, wrapError(err, "failed to generate password") + } + createReq.Password = plaintextPassword + plaintextData = append(plaintextData, &v2.PlaintextData{ + Name: "password", + Description: "Password for the user", + Bytes: []byte(plaintextPassword), + }) + } + + // Create user via REST API + _, rateLimitDesc, err := o.client.CreateUserREST(ctx, createReq) + if err != nil { + l.Error("failed to create user", + zap.String("user_name", userName), + zap.Error(err), + ) + var annos annotations.Annotations + if rateLimitDesc != nil { + annos = annotations.New(rateLimitDesc) + } + return nil, nil, annos, wrapError(err, "failed to create user") + } + + user, err := o.fetchUserWithSQLRetry(ctx, userName) + if err != nil { + l.Error("failed to fetch user after creation", + zap.String("user_name", userName), + zap.Error(err), + ) + annos := annotations.Annotations{} + if rateLimitDesc != nil { + annos.Update(rateLimitDesc) + } + return nil, nil, annos, wrapError(err, "failed to fetch user after creation") + } + + // Build resource for the new user + resource, err := userResource(ctx, user, o.syncSecrets) + if err != nil { + return nil, nil, nil, wrapError(err, "failed to create user resource") + } + + l.Debug("user created successfully", + zap.String("user_name", user.Username), + ) + + // Build annotations with rate limit information + var annos annotations.Annotations + if rateLimitDesc != nil { + annos = annotations.New(rateLimitDesc) + } + + // Return success result with plaintext data (password) + result := &v2.CreateAccountResponse_SuccessResult{ + Resource: resource, + } + + return result, plaintextData, annos, nil +} + +// fetchUserWithSQLRetry attempts to fetch a user using the SQL API with retry logic for 422 errors. +// Retries up to 5 times with exponential backoff if we get a 422 Unprocessable Entity error. +func (o *userBuilder) fetchUserWithSQLRetry(ctx context.Context, userName string) (*snowflake.User, error) { + l := ctxzap.Extract(ctx) + maxRetries := 5 + baseDelay := 500 * time.Millisecond + + for attempt := 0; attempt <= maxRetries; attempt++ { + user, statusCode, err := o.client.GetUser(ctx, userName) + if err == nil && statusCode == http.StatusOK { + l.Debug("user fetched successfully via SQL API", + zap.String("user_name", userName), + ) + return user, nil + } + + // Check if we got a 422 error + is422 := false + if statusCode == http.StatusUnprocessableEntity { + is422 = true + } else if err != nil { + errStr := strings.ToLower(err.Error()) + is422 = strings.Contains(errStr, "422") || strings.Contains(errStr, "unprocessable entity") + } + + // If it's not a 422 error, or we've exhausted retries, return the error + if !is422 || attempt >= maxRetries { + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("baton-snowflake: unexpected status code %d when fetching user", statusCode) + } + return nil, fmt.Errorf("baton-snowflake: failed to fetch user") + } + + // Calculate exponential backoff: baseDelay * 2^attempt + delay := baseDelay + for i := 0; i < attempt; i++ { + delay *= 2 + } + + l.Debug("user fetch returned 422, retrying with SQL API", + zap.String("user_name", userName), + zap.Int("attempt", attempt+1), + zap.Int("max_retries", maxRetries), + zap.Duration("delay", delay), + ) + + // Wait before retrying + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + // Continue to next iteration + } + } + + // This should not be reached, but handle it just in case + return nil, fmt.Errorf("baton-snowflake: failed to fetch user after %d retries", maxRetries) +} + +// Delete deletes a Snowflake user using the REST API. +func (o *userBuilder) Delete(ctx context.Context, resourceId *v2.ResourceId, parentResourceID *v2.ResourceId) (annotations.Annotations, error) { + l := ctxzap.Extract(ctx) + + userName := resourceId.Resource + if userName == "" { + return nil, fmt.Errorf("baton-snowflake: user name is required") + } + + // Quote the username to match the case-sensitive identifier created with quotes + // This ensures we delete the exact case-sensitive identifier + // Escape any double quotes in the username by doubling them + quotedUserName := quoteSnowflakeIdentifier(userName) + options := &snowflake.DeleteUserOptions{ + IfExists: true, + } + // Delete user via REST API + _, err := o.client.DeleteUserREST(ctx, quotedUserName, options) + if err != nil { + l.Error("failed to delete user", + zap.String("user_name", userName), + zap.Error(err), + ) + return nil, wrapError(err, "failed to delete user") + } + + l.Debug("user deleted successfully", + zap.String("user_name", userName), + ) + + return nil, nil +} + func newUserBuilder(client *snowflake.Client, syncSecrets bool) *userBuilder { return &userBuilder{ resourceType: userResourceType, diff --git a/pkg/snowflake/account_role.go b/pkg/snowflake/account_role.go index b0076bc3..4ded711f 100644 --- a/pkg/snowflake/account_role.go +++ b/pkg/snowflake/account_role.go @@ -3,7 +3,6 @@ package snowflake import ( "context" "fmt" - "net/http" "github.com/conductorone/baton-sdk/pkg/uhttp" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" @@ -65,7 +64,7 @@ func (r *ListAccountRoleGranteesRawResponse) GetAccountRoleGrantees() []AccountR return accountRoleGrantees } -func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) ([]AccountRole, *http.Response, error) { +func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) ([]AccountRole, error) { var queries []string if cursor != "" { @@ -76,13 +75,14 @@ func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListAccountRolesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { - return nil, nil, err + return nil, err } l := ctxzap.Extract(ctx) @@ -90,99 +90,121 @@ func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { - return nil, resp, err + return nil, err } accountRoles, err := response.GetAccountRoles() if err != nil { - return nil, resp, err + return nil, err } - return accountRoles, resp, nil + return accountRoles, nil } -func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) ([]AccountRoleGrantee, *http.Response, error) { +func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) ([]AccountRoleGrantee, error) { queries := []string{ fmt.Sprintf("SHOW GRANTS OF ROLE \"%s\";", roleName), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListAccountRoleGranteesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { - return nil, nil, err + return nil, err } req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { - return nil, resp, err + return nil, err } - return response.GetAccountRoleGrantees(), resp, nil + accountRoleGrantees := response.GetAccountRoleGrantees() + + return accountRoleGrantees, nil } -func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountRole, *http.Response, error) { +func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountRole, int, error) { queries := []string{ fmt.Sprintf("SHOW ROLES LIKE '%s' LIMIT 1;", roleName), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, 0, err } var response ListAccountRolesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { - return nil, nil, err + statusCode := 0 + if resp != nil { + statusCode = resp.StatusCode + } + return nil, statusCode, err } accountRoles, err := response.GetAccountRoles() if err != nil { - return nil, resp, err + return nil, resp.StatusCode, err } if len(accountRoles) == 0 { - return nil, resp, nil + return nil, resp.StatusCode, nil } - return &accountRoles[0], resp, nil + return &accountRoles[0], resp.StatusCode, nil } -func (c *Client) GrantAccountRole(ctx context.Context, roleName, userName string) (*http.Response, error) { +func (c *Client) GrantAccountRole(ctx context.Context, roleName, userName string) error { queries := []string{ fmt.Sprintf("GRANT ROLE \"%s\" TO USER \"%s\";", roleName, userName), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, err + return err } - return c.Do(req) + resp, err := c.Do(req) + defer closeResponseBody(resp) + if err != nil { + return err + } + + return nil } -func (c *Client) RevokeAccountRole(ctx context.Context, roleName, userName string) (*http.Response, error) { +func (c *Client) RevokeAccountRole(ctx context.Context, roleName, userName string) error { queries := []string{ fmt.Sprintf("REVOKE ROLE \"%s\" FROM USER \"%s\";", roleName, userName), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, err + return err + } + + resp, err := c.Do(req) + defer closeResponseBody(resp) + if err != nil { + return err } - return c.Do(req) + return nil } diff --git a/pkg/snowflake/client.go b/pkg/snowflake/client.go index 50aa219c..f34a0fb6 100644 --- a/pkg/snowflake/client.go +++ b/pkg/snowflake/client.go @@ -3,6 +3,7 @@ package snowflake import ( "context" "fmt" + "io" "net/http" "net/url" "reflect" @@ -16,6 +17,8 @@ import ( const ( AuthTypeHeaderKey = "X-Snowflake-Authorization-Token-Type" AuthTypeHeaderValue = "KEYPAIR_JWT" + RoleHeaderKey = "X-Snowflake-Role" + UserAdminRole = "USERADMIN" ) const ( @@ -240,3 +243,12 @@ func Contains[T comparable](ts []T, val T) bool { } return false } + +// closeResponseBody drains and closes the response body if it exists. +// This ensures proper resource cleanup and allows connection reuse. +func closeResponseBody(resp *http.Response) { + if resp != nil && resp.Body != nil { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + } +} diff --git a/pkg/snowflake/database.go b/pkg/snowflake/database.go index b99b0887..46356e55 100644 --- a/pkg/snowflake/database.go +++ b/pkg/snowflake/database.go @@ -3,7 +3,6 @@ package snowflake import ( "context" "fmt" - "net/http" "strings" "github.com/conductorone/baton-sdk/pkg/uhttp" @@ -60,7 +59,7 @@ func (r *ListDatabasesRawResponse) GetDatabases() ([]Database, error) { return databases, nil } -func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([]Database, *http.Response, error) { +func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([]Database, error) { var queries []string if cursor != "" { @@ -71,13 +70,14 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListDatabasesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { - return nil, nil, err + return nil, err } l := ctxzap.Extract(ctx) @@ -85,49 +85,53 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { - return nil, resp, err + return nil, err } dbs, err := response.GetDatabases() if err != nil { - return nil, resp, err + return nil, err } - return dbs, resp, nil + return dbs, nil } -func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, *http.Response, error) { +func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, int, error) { queries := []string{ fmt.Sprintf("SHOW DATABASES LIKE '%s' LIMIT 1;", name), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, 0, err } var response ListDatabasesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { - if IsUnprocessableEntity(resp, err) { - return nil, resp, nil + statusCode := 0 + if resp != nil { + statusCode = resp.StatusCode } - return nil, resp, err + return nil, statusCode, err } databases, err := response.GetDatabases() if err != nil { - return nil, resp, err + return nil, resp.StatusCode, err } + if len(databases) == 0 { - return nil, resp, fmt.Errorf("database with name %s not found", name) + return nil, resp.StatusCode, fmt.Errorf("database with name %s not found", name) } else if len(databases) > 1 { - return nil, resp, fmt.Errorf("expected 1 database with name %s, got %d", name, len(databases)) + return nil, resp.StatusCode, fmt.Errorf("expected 1 database with name %s, got %d", name, len(databases)) } - return &databases[0], resp, nil + return &databases[0], resp.StatusCode, nil } diff --git a/pkg/snowflake/helper.go b/pkg/snowflake/helper.go index 34f95339..5aa2253b 100644 --- a/pkg/snowflake/helper.go +++ b/pkg/snowflake/helper.go @@ -8,8 +8,8 @@ import ( // IsUnprocessableEntity reports whether the Snowflake API returned HTTP 422 (Unprocessable Entity). // Snowflake returns 422 for certain operations on system/predefined objects (e.g. SHOW GRANTS OF ROLE for ACCOUNTADMIN, // SHOW ROLES LIKE for some roles). Callers can treat this as "no data" or "not resolvable" instead of a hard error. -func IsUnprocessableEntity(resp *http.Response, err error) bool { - if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { +func IsUnprocessableEntity(statusCode int, err error) bool { + if statusCode == http.StatusUnprocessableEntity { return true } if err != nil && (strings.Contains(err.Error(), "422") || strings.Contains(err.Error(), "Unprocessable Entity")) { diff --git a/pkg/snowflake/secrets.go b/pkg/snowflake/secrets.go index fd5aa9bc..68ea873e 100644 --- a/pkg/snowflake/secrets.go +++ b/pkg/snowflake/secrets.go @@ -29,6 +29,7 @@ func (c *Client) ListSecrets(ctx context.Context, database string) ([]Secret, er var response ListSecretsRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { var errMsg struct { @@ -55,7 +56,6 @@ func (c *Client) ListSecrets(ctx context.Context, database string) ([]Secret, er return nil, err } - defer resp.Body.Close() secrets, err := response.ListSecrets() if err != nil { @@ -77,10 +77,10 @@ func (c *Client) UserRsa(ctx context.Context, username string) (*UserRsa, error) var response RsaGetUserRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer resp.Body.Close() secrets, err := response.GetUserRsa(ctx) if err != nil { diff --git a/pkg/snowflake/table.go b/pkg/snowflake/table.go index 7bef8f19..ebeba9b0 100644 --- a/pkg/snowflake/table.go +++ b/pkg/snowflake/table.go @@ -61,7 +61,7 @@ func (r *ListTablesRawResponse) ListTables() ([]Table, error) { const tableListCursorSep = "\x00" -func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit int) ([]Table, string, *http.Response, error) { +func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit int) ([]Table, string, error) { l := ctxzap.Extract(ctx) var q string @@ -81,43 +81,39 @@ func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit i req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, "", nil, err + return nil, "", err } var response ListTablesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { - if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { + if resp1 != nil && resp1.StatusCode == http.StatusUnprocessableEntity { l.Debug("Insufficient privileges for SHOW TABLES IN ACCOUNT") wrappedErr := fmt.Errorf("baton-snowflake: insufficient privileges for SHOW TABLES IN ACCOUNT: %w", err) - return nil, "", nil, status.Error(codes.PermissionDenied, wrappedErr.Error()) + return nil, "", status.Error(codes.PermissionDenied, wrappedErr.Error()) } - return nil, "", nil, err - } - if resp != nil { - defer resp.Body.Close() + return nil, "", err } req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, "", resp, err + return nil, "", err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { - if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { + if resp2 != nil && resp2.StatusCode == http.StatusUnprocessableEntity { l.Debug("Insufficient privileges for SHOW TABLES IN ACCOUNT (statement result)") wrappedErr := fmt.Errorf("baton-snowflake: insufficient privileges for SHOW TABLES IN ACCOUNT (statement result): %w", err) - return nil, "", nil, status.Error(codes.PermissionDenied, wrappedErr.Error()) + return nil, "", status.Error(codes.PermissionDenied, wrappedErr.Error()) } - return nil, "", resp, err - } - if resp != nil { - defer resp.Body.Close() + return nil, "", err } tables, err := response.ListTables() if err != nil { - return nil, "", resp, err + return nil, "", err } var nextCursor string @@ -125,7 +121,7 @@ func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit i last := tables[len(tables)-1] nextCursor = last.DatabaseName + tableListCursorSep + last.SchemaName + tableListCursorSep + last.Name } - return tables, nextCursor, resp, nil + return tables, nextCursor, nil } // escapeSingleQuote doubles single quotes for use inside SQL string literals. @@ -149,7 +145,7 @@ func escapeDoubleQuotedIdentifier(s string) string { return strings.ReplaceAll(s, `"`, `""`) } -func (c *Client) GetTable(ctx context.Context, database, schema, tableName string) (*Table, *http.Response, error) { +func (c *Client) GetTable(ctx context.Context, database, schema, tableName string) (*Table, error) { likePattern := escapeLikePattern(tableName) queries := []string{ fmt.Sprintf("SHOW TABLES LIKE '%s' ESCAPE '\\' IN SCHEMA \"%s\".\"%s\" LIMIT 1;", likePattern, escapeDoubleQuotedIdentifier(database), escapeDoubleQuotedIdentifier(schema)), @@ -157,46 +153,42 @@ func (c *Client) GetTable(ctx context.Context, database, schema, tableName strin req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListTablesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { - if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { - return nil, resp, nil + if resp1 != nil && resp1.StatusCode == http.StatusUnprocessableEntity { + return nil, nil } - return nil, nil, err - } - if resp != nil { - defer resp.Body.Close() + return nil, err } req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { - return nil, resp, err - } - if resp != nil { - defer resp.Body.Close() + return nil, err } tables, err := response.ListTables() if err != nil { - return nil, resp, err + return nil, err } // Filter by exact match (database, schema, and name) for _, table := range tables { if table.DatabaseName == database && table.SchemaName == schema && table.Name == tableName { - return &table, resp, nil + return &table, nil } } - return nil, resp, fmt.Errorf("table %s.%s.%s not found", database, schema, tableName) + return nil, fmt.Errorf("table %s.%s.%s not found", database, schema, tableName) } var tableGrantStructFieldToColumnMap = map[string]string{ diff --git a/pkg/snowflake/user.go b/pkg/snowflake/user.go index 47a8025d..adca85ed 100644 --- a/pkg/snowflake/user.go +++ b/pkg/snowflake/user.go @@ -3,7 +3,6 @@ package snowflake import ( "context" "fmt" - "net/http" "reflect" "strings" "time" @@ -31,9 +30,11 @@ var ( } // Sadly snowflake is inconsistent and returns different set of columns for DESC USER. + // These fields are ignored when parsing DESCRIBE USER output. ignoredUserStructFieldsForDescribeOperation = []string{ "HasRSAPublicKey", "HasPassword", + "LastSuccessLogin", // May not be present for newly created users } secretStructFieldToColumnMap = map[string]string{ @@ -181,7 +182,7 @@ func (r *GetUserRawResponse) GetValueByColumnName(columnName string) (string, bo return "", false } -func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]User, *http.Response, error) { +func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]User, error) { var queries []string if cursor != "" { queries = append(queries, fmt.Sprintf("SHOW USERS LIMIT %d FROM '%s';", limit, cursor)) @@ -191,54 +192,63 @@ func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]Use req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListUsersRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { - return nil, resp, err + return nil, err } req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { - return nil, resp, err + return nil, err } users, err := response.GetUsers() if err != nil { - return nil, resp, err + return nil, err } - return users, resp, nil + return users, nil } -func (c *Client) GetUser(ctx context.Context, username string) (*User, *http.Response, error) { +func (c *Client) GetUser(ctx context.Context, username string) (*User, int, error) { + // Escape double quotes in username by doubling them before quoting + escapedUsername := escapeDoubleQuotedIdentifier(username) queries := []string{ - fmt.Sprintf("DESCRIBE USER \"%s\";", username), + fmt.Sprintf("DESCRIBE USER \"%s\";", escapedUsername), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, 0, err } var response GetUserRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { - return nil, resp, err + statusCode := 0 + if resp != nil { + statusCode = resp.StatusCode + } + return nil, statusCode, err } user, err := response.GetUser() if err != nil { - return nil, resp, err + return nil, resp.StatusCode, err } - return user, resp, nil + return user, resp.StatusCode, nil } func (r *ListSecretsRawResponse) ListSecrets() ([]Secret, error) { diff --git a/pkg/snowflake/user_rest.go b/pkg/snowflake/user_rest.go new file mode 100644 index 00000000..b38fd25f --- /dev/null +++ b/pkg/snowflake/user_rest.go @@ -0,0 +1,210 @@ +package snowflake + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + + v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" + "github.com/conductorone/baton-sdk/pkg/uhttp" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" +) + +// CreateUserRequest represents the request body for creating a user via REST API. +type CreateUserRequest struct { + Name string `json:"name"` + LoginName string `json:"loginName,omitempty"` + DisplayName string `json:"displayName,omitempty"` + FirstName string `json:"firstName,omitempty"` + LastName string `json:"lastName,omitempty"` + Email string `json:"email,omitempty"` + Comment string `json:"comment,omitempty"` + Password string `json:"password,omitempty"` // #nosec G117: used for Snowflake API request body + MustChangePassword bool `json:"mustChangePassword,omitempty"` + Disabled bool `json:"disabled,omitempty"` + DefaultWarehouse string `json:"defaultWarehouse,omitempty"` + DefaultNamespace string `json:"defaultNamespace,omitempty"` + DefaultRole string `json:"defaultRole,omitempty"` + DefaultSecondaryRoles string `json:"defaultSecondaryRoles,omitempty"` // ALL or NONE +} + +// CreateUserResponse represents the response from creating a user. +type CreateUserResponse struct { + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +// DeleteUserOptions represents optional parameters for deleting a user. +type DeleteUserOptions struct { + IfExists bool `json:"ifExists,omitempty"` +} + +// SnowflakeError represents an error response from Snowflake REST API. +type SnowflakeError struct { + Code string `json:"code"` + ErrMsg string `json:"message"` +} + +// Message implements the ErrorResponse interface. +func (e *SnowflakeError) Message() string { + if e.ErrMsg != "" { + return e.ErrMsg + } + if e.Code != "" { + return e.Code + } + return "unknown error" +} + +// createUsersApiUrl creates the URL for the users REST API endpoint. +func createUsersApiUrl(accountUrl string) (*url.URL, error) { + stringUrl, err := url.JoinPath(accountUrl, "api/v2/users") + if err != nil { + return nil, err + } + + return url.Parse(stringUrl) +} + +// createUserApiUrl creates the URL for a specific user REST API endpoint. +// The userName should be the actual identifier (case-sensitive if created with quotes). +// url.JoinPath will properly encode the path segment. +func createUserApiUrl(accountUrl string, userName string) (*url.URL, error) { + stringUrl, err := url.JoinPath(accountUrl, "api/v2/users", userName) + if err != nil { + return nil, err + } + + return url.Parse(stringUrl) +} + +// doRequest is a helper method that wraps HTTP request logic for REST API calls. +// Returns the response headers, rate limit description, status code, and error. +func (c *Client) doRequest( + ctx context.Context, + method string, + endpoint *url.URL, + target interface{}, + body interface{}, + opts ...uhttp.RequestOption, +) (*http.Header, *v2.RateLimitDescription, int, error) { + l := ctxzap.Extract(ctx) + var requestOptions []uhttp.RequestOption + requestOptions = append(requestOptions, + uhttp.WithAcceptJSONHeader(), + uhttp.WithHeader(AuthTypeHeaderKey, AuthTypeHeaderValue)) + + // Append any additional options passed in + requestOptions = append(requestOptions, opts...) + + if body != nil { + requestOptions = append(requestOptions, uhttp.WithContentTypeJSONHeader(), uhttp.WithJSONBody(body)) + } + + request, err := c.NewRequest(ctx, method, endpoint, requestOptions...) + if err != nil { + return nil, nil, 0, fmt.Errorf("baton-snowflake: failed to create request: %w", err) + } + + var rateLimitData v2.RateLimitDescription + var errorResponse SnowflakeError + doOptions := []uhttp.DoOption{ + uhttp.WithRatelimitData(&rateLimitData), + uhttp.WithErrorResponse(&errorResponse), + } + if target != nil { + doOptions = append(doOptions, uhttp.WithJSONResponse(target)) + } + + response, err := c.Do(request, doOptions...) + defer func() { + if response == nil || response.Body == nil { + return + } + _, _ = io.Copy(io.Discard, response.Body) + closeErr := response.Body.Close() + if closeErr != nil { + l.Debug("baton-snowflake: warning: failed to close response body", zap.Error(closeErr)) + } + }() + if err != nil { + statusCode := 0 + if response != nil { + statusCode = response.StatusCode + } + return nil, &rateLimitData, statusCode, fmt.Errorf("baton-snowflake: request failed: %w", err) + } + + // WithErrorResponse ensures c.Do() returns an error for status >= 300, + // so if we reach here, statusCode is guaranteed to be < 300 + return &response.Header, &rateLimitData, response.StatusCode, nil +} + +// CreateUserREST creates a new Snowflake user using the REST API. +// POST /api/v2/users. +// Returns completed=true if status code is 200, false if 202 (accepted but not completed). +func (c *Client) CreateUserREST(ctx context.Context, req *CreateUserRequest) (bool, *v2.RateLimitDescription, error) { + l := ctxzap.Extract(ctx) + + usersApiUrl, err := createUsersApiUrl(c.AccountUrl) + if err != nil { + return false, nil, fmt.Errorf("baton-snowflake: failed to create users API URL: %w", err) + } + + var response CreateUserResponse + _, rateLimitDesc, statusCode, err := c.doRequest(ctx, http.MethodPost, usersApiUrl, &response, req, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) + if err != nil { + return false, rateLimitDesc, err + } + + completed := statusCode == http.StatusOK + + l.Debug("baton-snowflake: user creation request completed", + zap.String("user_name", req.Name), + zap.Int("status_code", statusCode), + zap.Bool("completed", completed), + zap.String("status", response.Status), + zap.String("message", response.Message), + ) + + return completed, rateLimitDesc, nil +} + +// DeleteUserREST deletes a Snowflake user using the REST API. +// DELETE /api/v2/users/{name}. +func (c *Client) DeleteUserREST(ctx context.Context, userName string, options *DeleteUserOptions) (*v2.RateLimitDescription, error) { + l := ctxzap.Extract(ctx) + + userApiUrl, err := createUserApiUrl(c.AccountUrl, userName) + if err != nil { + return nil, fmt.Errorf("baton-snowflake: failed to create user API URL: %w", err) + } + + // Add query parameters if options are provided + if options != nil { + query := userApiUrl.Query() + if options.IfExists { + query.Set("ifExists", "true") + } + userApiUrl.RawQuery = query.Encode() + } + + _, rateLimitDesc, _, err := c.doRequest(ctx, http.MethodDelete, userApiUrl, nil, nil, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) + if err != nil { + l.Error("baton-snowflake: failed to delete user", + zap.String("user_name", userName), + zap.Error(err), + ) + return rateLimitDesc, err + } + + l.Debug("baton-snowflake: user deleted successfully", + zap.String("user_name", userName), + ) + + return rateLimitDesc, nil +}