From 9b8efa02cec8968b28887c3f69e3b75839470ffe Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Tue, 10 Feb 2026 10:18:37 -0300 Subject: [PATCH 01/12] add user provisioning --- .github/workflows/ci.yaml | 54 ++++------ pkg/connector/connector.go | 132 +++++++++++++++++++++++++ pkg/connector/users.go | 191 ++++++++++++++++++++++++++++++++++++ pkg/snowflake/client.go | 2 + pkg/snowflake/user_rest.go | 195 +++++++++++++++++++++++++++++++++++++ 5 files changed, 537 insertions(+), 37 deletions(-) create mode 100644 pkg/snowflake/user_rest.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dc035996..9f373d16 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -43,19 +43,14 @@ jobs: test: runs-on: ubuntu-latest - # Define any services needed for the test suite (or delete this section) - # services: - # postgres: - # image: postgres:16 - # ports: - # - "5432:5432" - # env: - # POSTGRES_PASSWORD: secretpassword + if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main' env: BATON_LOG_LEVEL: debug - # Add any environment variables needed to run baton-snowflake - # BATON_BASE_URL: 'http://localhost:8080' - # BATON_ACCESS_TOKEN: 'secret_token' + # Snowflake connection configuration + BATON_ACCOUNT_IDENTIFIER: ${{ secrets.BATON_ACCOUNT_IDENTIFIER }} + BATON_USER_IDENTIFIER: ${{ secrets.BATON_USER_IDENTIFIER }} + BATON_ACCOUNT_URL: ${{ secrets.BATON_ACCOUNT_URL }} + BATON_PRIVATE_KEY: ${{ secrets.BATON_PRIVATE_KEY }} # The following parameters are passed to grant/revoke commands # Change these to the correct IDs for your test data CONNECTOR_GRANT: 'grant:entitlement:group:1234:member:user:9876' @@ -69,37 +64,22 @@ jobs: uses: actions/setup-go@v5 with: go-version-file: 'go.mod' - # Install any dependencies here (or delete this) - # - name: Install postgres client - # run: sudo apt install postgresql-client - # Run any fixture setup here (or delete this) - # - name: Import sql into postgres - # run: psql -h localhost --user postgres -f environment.sql - # env: - # PGPASSWORD: secretpassword + - name: Build baton-snowflake - run: go build ./cmd/baton-snowflake + run: go build -o baton-snowflake ./cmd/baton-snowflake # - name: Run baton-snowflake # run: ./baton-snowflake - name: Install baton run: ./scripts/get-baton.sh && mv baton /usr/local/bin - # - name: Check for grant before revoking - # run: - # baton grants --entitlement="${{ env.CONNECTOR_ENTITLEMENT }}" --output-format=json | jq --exit-status ".grants[].principal.id.resource == \"${{ env.CONNECTOR_PRINCIPAL }}\"" - - # - name: Revoke grants - # run: ./baton-snowflake --revoke-grant="${{ env.CONNECTOR_GRANT }}" - - # - name: Check grant was revoked - # run: ./baton-snowflake && baton grants --entitlement="${{ env.CONNECTOR_ENTITLEMENT }}" --output-format=json | jq --exit-status "if .grants then .grants[]?.principal.id.resource != \"${{ env.CONNECTOR_PRINCIPAL }}\" else . end" - - # - name: Grant entitlement - # # Change the grant arguments to the correct IDs for your test data - # run: ./baton-snowflake --grant-entitlement="${{ env.CONNECTOR_ENTITLEMENT }}" --grant-principal="${{ env.CONNECTOR_PRINCIPAL }}" --grant-principal-type="${{ env.CONNECTOR_PRINCIPAL_TYPE }}" - - # - name: Check grant was re-granted - # run: - # baton grants --entitlement="${{ env.CONNECTOR_ENTITLEMENT }}" --output-format=json | jq --exit-status ".grants[].principal.id.resource == \"${{ env.CONNECTOR_PRINCIPAL }}\"" + - name: Test Account Provisioning + uses: ConductorOne/github-workflows/.github/actions/account-provisioning@main + with: + connector: './baton-snowflake' + account-email: 'test-provisioning@example.com' + account-login: 'test-provisioning-user' + account-profile: '{"first_name": "Test", "last_name": "User", "name": "test-provisioning-user", "email": "test-provisioning@example.com"}' + account-type: 'user' + search-method: 'email' diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index b95c7c9b..071138af 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -49,6 +49,138 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) return &v2.ConnectorMetadata{ DisplayName: "Baton Snowflake", Description: "Connector syncing users, databases, tables, and account roles from Snowflake.", + AccountCreationSchema: &v2.ConnectorAccountCreationSchema{ + FieldMap: map[string]*v2.ConnectorAccountCreationSchema_Field{ + "name": { + DisplayName: "User Name", + Required: true, + Description: "The name of the user (required). Can be provided via login or profile.name", + Placeholder: "username", + Order: 0, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "login": { + DisplayName: "Login Name", + Required: false, + Description: "The login name for the user (defaults to email if not provided)", + Placeholder: "user@example.com", + Order: 1, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "display_name": { + DisplayName: "Display Name", + Required: false, + Description: "The display name for the user", + Placeholder: "John Doe", + Order: 2, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "first_name": { + DisplayName: "First Name", + Required: false, + Description: "The first name of the user", + Placeholder: "John", + Order: 3, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "last_name": { + DisplayName: "Last Name", + Required: false, + Description: "The last name of the user", + Placeholder: "Doe", + Order: 4, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "email": { + DisplayName: "Email", + Required: false, + Description: "The email address for the user", + Placeholder: "user@example.com", + Order: 5, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "comment": { + DisplayName: "Comment", + Required: false, + Description: "A comment or description for the user", + Placeholder: "User description", + Order: 6, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "disabled": { + DisplayName: "Disabled", + Required: false, + Description: "Whether the user account should be disabled", + Order: 8, + Field: &v2.ConnectorAccountCreationSchema_Field_BoolField{ + BoolField: &v2.ConnectorAccountCreationSchema_BoolField{}, + }, + }, + "must_change_password": { + DisplayName: "Must Change Password", + Required: false, + Description: "Whether the user must change their password on next login", + Order: 9, + Field: &v2.ConnectorAccountCreationSchema_Field_BoolField{ + BoolField: &v2.ConnectorAccountCreationSchema_BoolField{}, + }, + }, + "default_warehouse": { + DisplayName: "Default Warehouse", + Required: false, + Description: "The default warehouse to use when this user starts a session", + Placeholder: "COMPUTE_WH", + Order: 10, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "default_namespace": { + DisplayName: "Default Namespace", + Required: false, + Description: "The default namespace to use when this user starts a session", + Placeholder: "DATABASE.SCHEMA", + Order: 11, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "default_role": { + DisplayName: "Default Role", + Required: false, + Description: "The default role to use when this user starts a session", + Placeholder: "PUBLIC", + Order: 12, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + "default_secondary_roles": { + DisplayName: "Default Secondary Roles", + Required: false, + Description: "The default secondary roles of this user to use when starting a session. Valid values: ALL or NONE. Default is ALL.", + Placeholder: "ALL", + Order: 13, + Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ + StringField: &v2.ConnectorAccountCreationSchema_StringField{}, + }, + }, + }, + }, }, nil } diff --git a/pkg/connector/users.go b/pkg/connector/users.go index b546ce1a..aadf4221 100644 --- a/pkg/connector/users.go +++ b/pkg/connector/users.go @@ -2,10 +2,18 @@ package connector import ( "context" + "fmt" v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" + "github.com/conductorone/baton-sdk/pkg/annotations" + connectorbuilder "github.com/conductorone/baton-sdk/pkg/connectorbuilder" + "github.com/conductorone/baton-sdk/pkg/crypto" rs "github.com/conductorone/baton-sdk/pkg/types/resource" "github.com/conductorone/baton-snowflake/pkg/snowflake" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type userBuilder struct { @@ -99,6 +107,56 @@ func getUserDetailedStatus(user *snowflake.User) string { return "" } +// extractProfileFields extracts optional fields from the accountInfo profile and populates the createReq. +func extractProfileFields(accountInfo *v2.AccountInfo, createReq *snowflake.CreateUserRequest) { + profile := accountInfo.GetProfile() + if profile == nil { + return + } + + pMap := profile.AsMap() + + if loginNameStr, ok := pMap["login"].(string); ok && loginNameStr != "" { + createReq.LoginName = loginNameStr + } + if displayNameStr, ok := pMap["display_name"].(string); ok && displayNameStr != "" { + createReq.DisplayName = displayNameStr + } + if firstNameStr, ok := pMap["first_name"].(string); ok && firstNameStr != "" { + createReq.FirstName = firstNameStr + } + if lastNameStr, ok := pMap["last_name"].(string); ok && lastNameStr != "" { + createReq.LastName = lastNameStr + } + if emailStr, ok := pMap["email"].(string); ok && emailStr != "" { + createReq.Email = emailStr + } + if commentStr, ok := pMap["comment"].(string); ok && commentStr != "" { + createReq.Comment = commentStr + } + // Handle disabled as boolean + if disabledVal, ok := pMap["disabled"].(bool); ok { + createReq.Disabled = disabledVal + } + // Allow must_change_password to be overridden from profile (as boolean) + if mustChangeVal, ok := pMap["must_change_password"].(bool); ok { + createReq.MustChangePassword = mustChangeVal + } + // Default warehouse, namespace, role, and secondary roles + if defaultWarehouseStr, ok := pMap["default_warehouse"].(string); ok && defaultWarehouseStr != "" { + createReq.DefaultWarehouse = defaultWarehouseStr + } + if defaultNamespaceStr, ok := pMap["default_namespace"].(string); ok && defaultNamespaceStr != "" { + createReq.DefaultNamespace = defaultNamespaceStr + } + if defaultRoleStr, ok := pMap["default_role"].(string); ok && defaultRoleStr != "" { + createReq.DefaultRole = defaultRoleStr + } + if defaultSecondaryRolesStr, ok := pMap["default_secondary_roles"].(string); ok && defaultSecondaryRolesStr != "" { + createReq.DefaultSecondaryRoles = defaultSecondaryRolesStr + } +} + // List returns all the users from the database as resource objects. // Users include a UserTrait because they are the 'shape' of a standard user. func (o *userBuilder) List(ctx context.Context, parentResourceID *v2.ResourceId, opts rs.SyncOpAttrs) ([]*v2.Resource, *rs.SyncOpResults, error) { @@ -148,6 +206,139 @@ func (o *userBuilder) Grants(ctx context.Context, resource *v2.Resource, _ rs.Sy return nil, nil, nil } +// CreateAccountCapabilityDetails returns the capability details for user account provisioning. +func (o *userBuilder) CreateAccountCapabilityDetails(ctx context.Context) (*v2.CredentialDetailsAccountProvisioning, annotations.Annotations, error) { + return &v2.CredentialDetailsAccountProvisioning{ + SupportedCredentialOptions: []v2.CapabilityDetailCredentialOption{ + v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_RANDOM_PASSWORD, + v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_ENCRYPTED_PASSWORD, + }, + PreferredCredentialOption: v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_RANDOM_PASSWORD, + }, nil, nil +} + +// CreateAccount creates a new Snowflake user using the REST API. +func (o *userBuilder) CreateAccount( + ctx context.Context, + accountInfo *v2.AccountInfo, + credentialOptions *v2.LocalCredentialOptions, +) (connectorbuilder.CreateAccountResponse, []*v2.PlaintextData, annotations.Annotations, error) { + l := ctxzap.Extract(ctx) + + // Extract user name from accountInfo + // The user name should come from profile.name first (required in schema), then fall back to Login + userName := "" + if profile := accountInfo.GetProfile(); profile != nil { + if nameStr, ok := rs.GetProfileStringValue(profile, "name"); ok && nameStr != "" { + userName = nameStr + } + } + // Fall back to login if profile name is not available + if userName == "" { + userName = accountInfo.GetLogin() + } + + if userName == "" { + return nil, nil, nil, status.Error(codes.InvalidArgument, "baton-snowflake: user name is required (provide via profile.name or login)") + } + + // Build create user request + // name is the only required field for the create user request + createReq := &snowflake.CreateUserRequest{ + Name: userName, + } + + // Extract optional fields from profile (login and email are optional - only set if provided in profile) + extractProfileFields(accountInfo, createReq) + + // Handle password generation + var plaintextData []*v2.PlaintextData + if credentialOptions != nil { + // Generate password if random password is requested + if randomPassword := credentialOptions.GetRandomPassword(); randomPassword != nil { + password, err := crypto.GeneratePassword(ctx, credentialOptions) + if err != nil { + return nil, nil, nil, wrapError(err, "failed to generate random password") + } + createReq.Password = password + + // Return the plaintext password so it can be encrypted and returned to the caller + plaintextData = append(plaintextData, v2.PlaintextData_builder{ + Name: "password", + Description: "Generated password for Snowflake user", + Bytes: []byte(password), + }.Build()) + } else if plaintextPassword := credentialOptions.GetPlaintextPassword(); plaintextPassword != nil { + // Use provided plaintext password + createReq.Password = plaintextPassword.GetPlaintextPassword() + } + } + + // Create user via REST API + user, rateLimitDesc, err := o.client.CreateUserREST(ctx, createReq) + if err != nil { + l.Error("failed to create user", + zap.String("user_name", userName), + zap.Error(err), + ) + var annos annotations.Annotations + if rateLimitDesc != nil { + annos = annotations.New(rateLimitDesc) + } + return nil, nil, annos, wrapError(err, "failed to create user") + } + + // Build resource for the new user + resource, err := userResource(ctx, user, o.syncSecrets) + if err != nil { + return nil, nil, nil, wrapError(err, "failed to create user resource") + } + + l.Debug("user created successfully", + zap.String("user_name", user.Username), + ) + + // Build annotations with rate limit information + var annos annotations.Annotations + if rateLimitDesc != nil { + annos = annotations.New(rateLimitDesc) + } + + // Return success result with plaintext data (password) + result := &v2.CreateAccountResponse_SuccessResult{ + Resource: resource, + } + + return result, plaintextData, annos, nil +} + +// Delete deletes a Snowflake user using the REST API. +func (o *userBuilder) Delete(ctx context.Context, resourceId *v2.ResourceId, parentResourceID *v2.ResourceId) (annotations.Annotations, error) { + l := ctxzap.Extract(ctx) + + userName := resourceId.Resource + if userName == "" { + return nil, fmt.Errorf("baton-snowflake: user name is required") + } + + // Delete user via REST API + // Using default options (ifExists=false) + _, err := o.client.DeleteUserREST(ctx, userName, nil) + if err != nil { + l.Error("failed to delete user", + zap.String("user_name", userName), + zap.Error(err), + ) + return nil, wrapError(err, "failed to delete user") + } + + l.Debug("user deleted successfully", + zap.String("user_name", userName), + ) + + return nil, nil +} + func newUserBuilder(client *snowflake.Client, syncSecrets bool) *userBuilder { return &userBuilder{ resourceType: userResourceType, diff --git a/pkg/snowflake/client.go b/pkg/snowflake/client.go index 50aa219c..7feb76d8 100644 --- a/pkg/snowflake/client.go +++ b/pkg/snowflake/client.go @@ -16,6 +16,8 @@ import ( const ( AuthTypeHeaderKey = "X-Snowflake-Authorization-Token-Type" AuthTypeHeaderValue = "KEYPAIR_JWT" + RoleHeaderKey = "X-Snowflake-Role" + UserAdminRole = "USERADMIN" ) const ( diff --git a/pkg/snowflake/user_rest.go b/pkg/snowflake/user_rest.go new file mode 100644 index 00000000..0353417e --- /dev/null +++ b/pkg/snowflake/user_rest.go @@ -0,0 +1,195 @@ +package snowflake + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "net/url" + + v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" + "github.com/conductorone/baton-sdk/pkg/uhttp" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" +) + +// CreateUserRequest represents the request body for creating a user via REST API. +type CreateUserRequest struct { + Name string `json:"name"` + LoginName string `json:"loginName,omitempty"` + DisplayName string `json:"displayName,omitempty"` + FirstName string `json:"firstName,omitempty"` + LastName string `json:"lastName,omitempty"` + Email string `json:"email,omitempty"` + Comment string `json:"comment,omitempty"` + Password string `json:"password,omitempty"` + MustChangePassword bool `json:"mustChangePassword,omitempty"` + Disabled bool `json:"disabled,omitempty"` + DefaultWarehouse string `json:"defaultWarehouse,omitempty"` + DefaultNamespace string `json:"defaultNamespace,omitempty"` + DefaultRole string `json:"defaultRole,omitempty"` + DefaultSecondaryRoles string `json:"defaultSecondaryRoles,omitempty"` // ALL or NONE +} + +// CreateUserResponse represents the response from creating a user. +type CreateUserResponse struct { + User User `json:"user"` + Message string `json:"message,omitempty"` +} + +// DeleteUserOptions represents optional parameters for deleting a user. +type DeleteUserOptions struct { + IfExists bool `json:"ifExists,omitempty"` +} + +// SnowflakeError represents an error response from Snowflake REST API. +type SnowflakeError struct { + Code string `json:"code"` + ErrMsg string `json:"message"` +} + +// Message implements the ErrorResponse interface. +func (e *SnowflakeError) Message() string { + if e.ErrMsg != "" { + return e.ErrMsg + } + if e.Code != "" { + return e.Code + } + return "unknown error" +} + +// createUsersApiUrl creates the URL for the users REST API endpoint. +func createUsersApiUrl(accountUrl string) (*url.URL, error) { + stringUrl, err := url.JoinPath(accountUrl, "api/v2/users") + if err != nil { + return nil, err + } + + return url.Parse(stringUrl) +} + +// createUserApiUrl creates the URL for a specific user REST API endpoint. +func createUserApiUrl(accountUrl string, userName string) (*url.URL, error) { + stringUrl, err := url.JoinPath(accountUrl, "api/v2/users", userName) + if err != nil { + return nil, err + } + + return url.Parse(stringUrl) +} + +// doRequest is a helper method that wraps HTTP request logic for REST API calls. +func (c *Client) doRequest(ctx context.Context, method string, endpoint *url.URL, target interface{}, body interface{}, opts ...uhttp.RequestOption) (*http.Header, *v2.RateLimitDescription, error) { + var requestOptions []uhttp.RequestOption + requestOptions = append(requestOptions, + uhttp.WithAcceptJSONHeader(), + uhttp.WithHeader(AuthTypeHeaderKey, AuthTypeHeaderValue)) + + // Append any additional options passed in + requestOptions = append(requestOptions, opts...) + + if body != nil { + requestOptions = append(requestOptions, uhttp.WithContentTypeJSONHeader(), uhttp.WithJSONBody(body)) + } + + request, err := c.NewRequest(ctx, method, endpoint, requestOptions...) + if err != nil { + return nil, nil, fmt.Errorf("failed to create request: %w", err) + } + + var rateLimitData v2.RateLimitDescription + var errorResponse SnowflakeError + doOptions := []uhttp.DoOption{ + uhttp.WithRatelimitData(&rateLimitData), + uhttp.WithErrorResponse(&errorResponse), + } + if target != nil { + doOptions = append(doOptions, uhttp.WithJSONResponse(target)) + } + + response, err := c.Do(request, doOptions...) + if err != nil { + return nil, &rateLimitData, fmt.Errorf("request failed: %w", err) + } + defer func() { + _, _ = io.Copy(io.Discard, response.Body) + closeErr := response.Body.Close() + if closeErr != nil { + log.Printf("warning: failed to close response body: %v", closeErr) + } + }() + + if response.StatusCode >= 300 { + // Try to extract error message from response + if errorResponse.Code != "" || errorResponse.ErrMsg != "" { + return &response.Header, &rateLimitData, fmt.Errorf("snowflake API error: %s - %s", errorResponse.Code, errorResponse.Message()) + } + return &response.Header, &rateLimitData, fmt.Errorf("unexpected status code %d", response.StatusCode) + } + + return &response.Header, &rateLimitData, nil +} + +// CreateUserREST creates a new Snowflake user using the REST API. +// POST /api/v2/users. +func (c *Client) CreateUserREST(ctx context.Context, req *CreateUserRequest) (*User, *v2.RateLimitDescription, error) { + l := ctxzap.Extract(ctx) + + usersApiUrl, err := createUsersApiUrl(c.AccountUrl) + if err != nil { + return nil, nil, fmt.Errorf("failed to create users API URL: %w", err) + } + + var response CreateUserResponse + _, rateLimitDesc, err := c.doRequest(ctx, http.MethodPost, usersApiUrl, &response, req, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) + if err != nil { + l.Error("failed to create user", + zap.String("user_name", req.Name), + zap.Error(err), + ) + return nil, rateLimitDesc, err + } + + l.Debug("user created successfully", + zap.String("user_name", response.User.Username), + ) + + return &response.User, rateLimitDesc, nil +} + +// DeleteUserREST deletes a Snowflake user using the REST API. +// DELETE /api/v2/users/{name}. +func (c *Client) DeleteUserREST(ctx context.Context, userName string, options *DeleteUserOptions) (*v2.RateLimitDescription, error) { + l := ctxzap.Extract(ctx) + + userApiUrl, err := createUserApiUrl(c.AccountUrl, userName) + if err != nil { + return nil, fmt.Errorf("failed to create user API URL: %w", err) + } + + // Add query parameters if options are provided + if options != nil { + query := userApiUrl.Query() + if options.IfExists { + query.Set("ifExists", "true") + } + userApiUrl.RawQuery = query.Encode() + } + + _, rateLimitDesc, err := c.doRequest(ctx, http.MethodDelete, userApiUrl, nil, nil, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) + if err != nil { + l.Error("failed to delete user", + zap.String("user_name", userName), + zap.Error(err), + ) + return rateLimitDesc, err + } + + l.Debug("user deleted successfully", + zap.String("user_name", userName), + ) + + return rateLimitDesc, nil +} From d578bd811f6564cb3c55dc7d0fea44d9cd0514aa Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Tue, 10 Feb 2026 11:29:37 -0300 Subject: [PATCH 02/12] fix ci --- .github/workflows/capabilities_and_config.yaml | 11 +++++++---- .github/workflows/ci.yaml | 16 +++++++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/.github/workflows/capabilities_and_config.yaml b/.github/workflows/capabilities_and_config.yaml index c7ea94e4..a410d6fb 100644 --- a/.github/workflows/capabilities_and_config.yaml +++ b/.github/workflows/capabilities_and_config.yaml @@ -27,15 +27,18 @@ jobs: - name: Run and save config output run: ./connector config > config_schema.json + - name: Setup private key + run: | + echo "${{ secrets.BATON_PRIVATE_KEY }}" | base64 -d > /tmp/snowflake_key.p8 + chmod 600 /tmp/snowflake_key.p8 + - name: Run and save capabilities output env: BATON_ACCOUNT_IDENTIFIER: example BATON_USER_IDENTIFIER: example BATON_ACCOUNT_URL: https://example.snowflakecomputing.com - BATON_PRIVATE_KEY_PATH: ${{ runner.temp }}/baton_private_key.pem - run: | - openssl genrsa -out "$BATON_PRIVATE_KEY_PATH" 2048 - ./connector --sync-secrets capabilities > baton_capabilities.json + BATON_PRIVATE_KEY_PATH: /tmp/snowflake_key.p8 + run: ./connector --sync-secrets capabilities > baton_capabilities.json - name: Commit changes uses: EndBug/add-and-commit@v9 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9f373d16..61539aec 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -50,7 +50,6 @@ jobs: BATON_ACCOUNT_IDENTIFIER: ${{ secrets.BATON_ACCOUNT_IDENTIFIER }} BATON_USER_IDENTIFIER: ${{ secrets.BATON_USER_IDENTIFIER }} BATON_ACCOUNT_URL: ${{ secrets.BATON_ACCOUNT_URL }} - BATON_PRIVATE_KEY: ${{ secrets.BATON_PRIVATE_KEY }} # The following parameters are passed to grant/revoke commands # Change these to the correct IDs for your test data CONNECTOR_GRANT: 'grant:entitlement:group:1234:member:user:9876' @@ -73,13 +72,20 @@ jobs: - name: Install baton run: ./scripts/get-baton.sh && mv baton /usr/local/bin + - name: Setup private key + run: | + echo "${{ secrets.BATON_PRIVATE_KEY }}" | base64 -d > /tmp/snowflake_key.p8 + chmod 600 /tmp/snowflake_key.p8 + - name: Test Account Provisioning - uses: ConductorOne/github-workflows/.github/actions/account-provisioning@main + uses: ConductorOne/github-workflows/actions/account-provisioning@v4 with: connector: './baton-snowflake' - account-email: 'test-provisioning@example.com' - account-login: 'test-provisioning-user' - account-profile: '{"first_name": "Test", "last_name": "User", "name": "test-provisioning-user", "email": "test-provisioning@example.com"}' + account-email: 'testProvisioningUser@example.com' + account-login: 'testProvisioningUser' + account-profile: '{"first_name": "Test", "last_name": "User", "name": "testProvisioningUser", "email": "testProvisioningUser@example.com"}' account-type: 'user' search-method: 'email' + env: + BATON_PRIVATE_KEY_PATH: /tmp/snowflake_key.p8 From 744a318b1fd31993ca11db37cd027974407e2cf1 Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Tue, 10 Feb 2026 14:49:12 -0300 Subject: [PATCH 03/12] address minor pr comments --- pkg/snowflake/user_rest.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pkg/snowflake/user_rest.go b/pkg/snowflake/user_rest.go index 0353417e..6cc63646 100644 --- a/pkg/snowflake/user_rest.go +++ b/pkg/snowflake/user_rest.go @@ -96,7 +96,7 @@ func (c *Client) doRequest(ctx context.Context, method string, endpoint *url.URL request, err := c.NewRequest(ctx, method, endpoint, requestOptions...) if err != nil { - return nil, nil, fmt.Errorf("failed to create request: %w", err) + return nil, nil, fmt.Errorf("baton-snowflake: failed to create request: %w", err) } var rateLimitData v2.RateLimitDescription @@ -111,22 +111,25 @@ func (c *Client) doRequest(ctx context.Context, method string, endpoint *url.URL response, err := c.Do(request, doOptions...) if err != nil { - return nil, &rateLimitData, fmt.Errorf("request failed: %w", err) + return nil, &rateLimitData, fmt.Errorf("baton-snowflake: request failed: %w", err) } defer func() { + if response == nil || response.Body == nil { + return + } _, _ = io.Copy(io.Discard, response.Body) closeErr := response.Body.Close() if closeErr != nil { - log.Printf("warning: failed to close response body: %v", closeErr) + log.Printf("baton-snowflake: warning: failed to close response body: %v", closeErr) } }() if response.StatusCode >= 300 { // Try to extract error message from response if errorResponse.Code != "" || errorResponse.ErrMsg != "" { - return &response.Header, &rateLimitData, fmt.Errorf("snowflake API error: %s - %s", errorResponse.Code, errorResponse.Message()) + return &response.Header, &rateLimitData, fmt.Errorf("baton-snowflake: snowflake API error: %s - %s", errorResponse.Code, errorResponse.Message()) } - return &response.Header, &rateLimitData, fmt.Errorf("unexpected status code %d", response.StatusCode) + return &response.Header, &rateLimitData, fmt.Errorf("baton-snowflake: unexpected status code %d", response.StatusCode) } return &response.Header, &rateLimitData, nil @@ -139,20 +142,20 @@ func (c *Client) CreateUserREST(ctx context.Context, req *CreateUserRequest) (*U usersApiUrl, err := createUsersApiUrl(c.AccountUrl) if err != nil { - return nil, nil, fmt.Errorf("failed to create users API URL: %w", err) + return nil, nil, fmt.Errorf("baton-snowflake: failed to create users API URL: %w", err) } var response CreateUserResponse _, rateLimitDesc, err := c.doRequest(ctx, http.MethodPost, usersApiUrl, &response, req, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) if err != nil { - l.Error("failed to create user", + l.Error("baton-snowflake: failed to create user", zap.String("user_name", req.Name), zap.Error(err), ) return nil, rateLimitDesc, err } - l.Debug("user created successfully", + l.Debug("baton-snowflake: user created successfully", zap.String("user_name", response.User.Username), ) @@ -166,7 +169,7 @@ func (c *Client) DeleteUserREST(ctx context.Context, userName string, options *D userApiUrl, err := createUserApiUrl(c.AccountUrl, userName) if err != nil { - return nil, fmt.Errorf("failed to create user API URL: %w", err) + return nil, fmt.Errorf("baton-snowflake: failed to create user API URL: %w", err) } // Add query parameters if options are provided @@ -180,14 +183,14 @@ func (c *Client) DeleteUserREST(ctx context.Context, userName string, options *D _, rateLimitDesc, err := c.doRequest(ctx, http.MethodDelete, userApiUrl, nil, nil, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) if err != nil { - l.Error("failed to delete user", + l.Error("baton-snowflake: failed to delete user", zap.String("user_name", userName), zap.Error(err), ) return rateLimitDesc, err } - l.Debug("user deleted successfully", + l.Debug("baton-snowflake: user deleted successfully", zap.String("user_name", userName), ) From b891f5904870e54d7a6588dd040bbb480475cde8 Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Wed, 11 Feb 2026 18:03:50 -0300 Subject: [PATCH 04/12] fetch user after create --- pkg/connector/connector.go | 14 +++--- pkg/connector/users.go | 89 +++++++++++++++++++++++++++++++++-- pkg/snowflake/account_role.go | 29 ++++++++++-- pkg/snowflake/client.go | 10 ++++ pkg/snowflake/database.go | 6 ++- pkg/snowflake/user.go | 7 ++- pkg/snowflake/user_rest.go | 55 ++++++++++++++-------- 7 files changed, 173 insertions(+), 37 deletions(-) diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 071138af..169c98d9 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -54,7 +54,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) "name": { DisplayName: "User Name", Required: true, - Description: "The name of the user (required). Can be provided via login or profile.name", + Description: "The name of the user (required - case-sensitive)", Placeholder: "username", Order: 0, Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ @@ -125,7 +125,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) DisplayName: "Disabled", Required: false, Description: "Whether the user account should be disabled", - Order: 8, + Order: 7, Field: &v2.ConnectorAccountCreationSchema_Field_BoolField{ BoolField: &v2.ConnectorAccountCreationSchema_BoolField{}, }, @@ -134,7 +134,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) DisplayName: "Must Change Password", Required: false, Description: "Whether the user must change their password on next login", - Order: 9, + Order: 8, Field: &v2.ConnectorAccountCreationSchema_Field_BoolField{ BoolField: &v2.ConnectorAccountCreationSchema_BoolField{}, }, @@ -144,7 +144,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) Required: false, Description: "The default warehouse to use when this user starts a session", Placeholder: "COMPUTE_WH", - Order: 10, + Order: 9, Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ StringField: &v2.ConnectorAccountCreationSchema_StringField{}, }, @@ -154,7 +154,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) Required: false, Description: "The default namespace to use when this user starts a session", Placeholder: "DATABASE.SCHEMA", - Order: 11, + Order: 10, Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ StringField: &v2.ConnectorAccountCreationSchema_StringField{}, }, @@ -164,7 +164,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) Required: false, Description: "The default role to use when this user starts a session", Placeholder: "PUBLIC", - Order: 12, + Order: 11, Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ StringField: &v2.ConnectorAccountCreationSchema_StringField{}, }, @@ -174,7 +174,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) Required: false, Description: "The default secondary roles of this user to use when starting a session. Valid values: ALL or NONE. Default is ALL.", Placeholder: "ALL", - Order: 13, + Order: 12, Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ StringField: &v2.ConnectorAccountCreationSchema_StringField{}, }, diff --git a/pkg/connector/users.go b/pkg/connector/users.go index aadf4221..3a446880 100644 --- a/pkg/connector/users.go +++ b/pkg/connector/users.go @@ -3,6 +3,9 @@ package connector import ( "context" "fmt" + "net/http" + "strings" + "time" v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" "github.com/conductorone/baton-sdk/pkg/annotations" @@ -244,8 +247,10 @@ func (o *userBuilder) CreateAccount( // Build create user request // name is the only required field for the create user request + // Quote the username to preserve case sensitivity (Snowflake stores unquoted identifiers in uppercase) + quotedUserName := fmt.Sprintf("\"%s\"", userName) createReq := &snowflake.CreateUserRequest{ - Name: userName, + Name: quotedUserName, } // Extract optional fields from profile (login and email are optional - only set if provided in profile) @@ -275,7 +280,7 @@ func (o *userBuilder) CreateAccount( } // Create user via REST API - user, rateLimitDesc, err := o.client.CreateUserREST(ctx, createReq) + _, rateLimitDesc, err := o.client.CreateUserREST(ctx, createReq) if err != nil { l.Error("failed to create user", zap.String("user_name", userName), @@ -288,6 +293,19 @@ func (o *userBuilder) CreateAccount( return nil, nil, annos, wrapError(err, "failed to create user") } + user, err := o.fetchUserWithSQLRetry(ctx, userName) + if err != nil { + l.Error("failed to fetch user after creation", + zap.String("user_name", userName), + zap.Error(err), + ) + annos := annotations.Annotations{} + if rateLimitDesc != nil { + annos.Update(rateLimitDesc) + } + return nil, nil, annos, wrapError(err, "failed to fetch user after creation") + } + // Build resource for the new user resource, err := userResource(ctx, user, o.syncSecrets) if err != nil { @@ -312,6 +330,68 @@ func (o *userBuilder) CreateAccount( return result, plaintextData, annos, nil } +// fetchUserWithSQLRetry attempts to fetch a user using the SQL API with retry logic for 422 errors. +// Retries up to 5 times with exponential backoff if we get a 422 Unprocessable Entity error. +func (o *userBuilder) fetchUserWithSQLRetry(ctx context.Context, userName string) (*snowflake.User, error) { + l := ctxzap.Extract(ctx) + maxRetries := 5 + baseDelay := 500 * time.Millisecond + + for attempt := 0; attempt <= maxRetries; attempt++ { + user, resp, err := o.client.GetUser(ctx, userName) + if err == nil && resp != nil && resp.StatusCode == http.StatusOK { + l.Debug("user fetched successfully via SQL API", + zap.String("user_name", userName), + ) + return user, nil + } + + // Check if we got a 422 error + is422 := false + if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { + is422 = true + } else if err != nil { + errStr := strings.ToLower(err.Error()) + is422 = strings.Contains(errStr, "422") || strings.Contains(errStr, "unprocessable entity") + } + + // If it's not a 422 error, or we've exhausted retries, return the error + if !is422 || attempt >= maxRetries { + if err != nil { + return nil, err + } + if resp != nil && resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("baton-snowflake: unexpected status code %d when fetching user", resp.StatusCode) + } + return nil, fmt.Errorf("baton-snowflake: failed to fetch user") + } + + // Calculate exponential backoff: baseDelay * 2^attempt + delay := baseDelay + for i := 0; i < attempt; i++ { + delay *= 2 + } + + l.Debug("user fetch returned 422, retrying with SQL API", + zap.String("user_name", userName), + zap.Int("attempt", attempt+1), + zap.Int("max_retries", maxRetries), + zap.Duration("delay", delay), + ) + + // Wait before retrying + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + // Continue to next iteration + } + } + + // This should not be reached, but handle it just in case + return nil, fmt.Errorf("baton-snowflake: failed to fetch user after %d retries", maxRetries) +} + // Delete deletes a Snowflake user using the REST API. func (o *userBuilder) Delete(ctx context.Context, resourceId *v2.ResourceId, parentResourceID *v2.ResourceId) (annotations.Annotations, error) { l := ctxzap.Extract(ctx) @@ -321,9 +401,12 @@ func (o *userBuilder) Delete(ctx context.Context, resourceId *v2.ResourceId, par return nil, fmt.Errorf("baton-snowflake: user name is required") } + // Quote the username to match the case-sensitive identifier created with quotes + // This ensures we delete the exact case-sensitive identifier + quotedUserName := fmt.Sprintf("\"%s\"", userName) // Delete user via REST API // Using default options (ifExists=false) - _, err := o.client.DeleteUserREST(ctx, userName, nil) + _, err := o.client.DeleteUserREST(ctx, quotedUserName, nil) if err != nil { l.Error("failed to delete user", zap.String("user_name", userName), diff --git a/pkg/snowflake/account_role.go b/pkg/snowflake/account_role.go index b0076bc3..3cea4a0e 100644 --- a/pkg/snowflake/account_role.go +++ b/pkg/snowflake/account_role.go @@ -84,18 +84,20 @@ func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) if err != nil { return nil, nil, err } + defer closeResponseBody(resp) l := ctxzap.Extract(ctx) l.Debug("ListAccountRoles", zap.String("response.code", response.Code), zap.String("response.message", response.Message)) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { return nil, resp, err } + defer closeResponseBody(resp) accountRoles, err := response.GetAccountRoles() if err != nil { @@ -120,17 +122,21 @@ func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) ( if err != nil { return nil, nil, err } + defer closeResponseBody(resp) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { return nil, resp, err } + defer closeResponseBody(resp) + + accountRoleGrantees := response.GetAccountRoleGrantees() - return response.GetAccountRoleGrantees(), resp, nil + return accountRoleGrantees, resp, nil } func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountRole, *http.Response, error) { @@ -148,6 +154,7 @@ func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountR if err != nil { return nil, nil, err } + defer closeResponseBody(resp) accountRoles, err := response.GetAccountRoles() if err != nil { @@ -171,7 +178,13 @@ func (c *Client) GrantAccountRole(ctx context.Context, roleName, userName string return nil, err } - return c.Do(req) + resp, err := c.Do(req) + if err != nil { + return resp, err + } + defer closeResponseBody(resp) + + return resp, nil } func (c *Client) RevokeAccountRole(ctx context.Context, roleName, userName string) (*http.Response, error) { @@ -184,5 +197,11 @@ func (c *Client) RevokeAccountRole(ctx context.Context, roleName, userName strin return nil, err } - return c.Do(req) + resp, err := c.Do(req) + if err != nil { + return resp, err + } + defer closeResponseBody(resp) + + return resp, nil } diff --git a/pkg/snowflake/client.go b/pkg/snowflake/client.go index 7feb76d8..f34a0fb6 100644 --- a/pkg/snowflake/client.go +++ b/pkg/snowflake/client.go @@ -3,6 +3,7 @@ package snowflake import ( "context" "fmt" + "io" "net/http" "net/url" "reflect" @@ -242,3 +243,12 @@ func Contains[T comparable](ts []T, val T) bool { } return false } + +// closeResponseBody drains and closes the response body if it exists. +// This ensures proper resource cleanup and allows connection reuse. +func closeResponseBody(resp *http.Response) { + if resp != nil && resp.Body != nil { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + } +} diff --git a/pkg/snowflake/database.go b/pkg/snowflake/database.go index b99b0887..de4c2abe 100644 --- a/pkg/snowflake/database.go +++ b/pkg/snowflake/database.go @@ -79,18 +79,20 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ if err != nil { return nil, nil, err } + defer closeResponseBody(resp) l := ctxzap.Extract(ctx) l.Debug("ListDatabases", zap.String("response.code", response.Code), zap.String("response.message", response.Message)) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { return nil, resp, err } + defer closeResponseBody(resp) dbs, err := response.GetDatabases() if err != nil { @@ -118,11 +120,13 @@ func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, *http } return nil, resp, err } + defer closeResponseBody(resp) databases, err := response.GetDatabases() if err != nil { return nil, resp, err } + if len(databases) == 0 { return nil, resp, fmt.Errorf("database with name %s not found", name) } else if len(databases) > 1 { diff --git a/pkg/snowflake/user.go b/pkg/snowflake/user.go index 47a8025d..eafa7c4a 100644 --- a/pkg/snowflake/user.go +++ b/pkg/snowflake/user.go @@ -31,9 +31,11 @@ var ( } // Sadly snowflake is inconsistent and returns different set of columns for DESC USER. + // These fields are ignored when parsing DESCRIBE USER output. ignoredUserStructFieldsForDescribeOperation = []string{ "HasRSAPublicKey", "HasPassword", + "LastSuccessLogin", // May not be present for newly created users } secretStructFieldToColumnMap = map[string]string{ @@ -199,15 +201,17 @@ func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]Use if err != nil { return nil, resp, err } + defer closeResponseBody(resp) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { return nil, resp, err } + defer closeResponseBody(resp) users, err := response.GetUsers() if err != nil { @@ -232,6 +236,7 @@ func (c *Client) GetUser(ctx context.Context, username string) (*User, *http.Res if err != nil { return nil, resp, err } + defer closeResponseBody(resp) user, err := response.GetUser() if err != nil { diff --git a/pkg/snowflake/user_rest.go b/pkg/snowflake/user_rest.go index 6cc63646..1f620799 100644 --- a/pkg/snowflake/user_rest.go +++ b/pkg/snowflake/user_rest.go @@ -34,7 +34,8 @@ type CreateUserRequest struct { // CreateUserResponse represents the response from creating a user. type CreateUserResponse struct { - User User `json:"user"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` Message string `json:"message,omitempty"` } @@ -71,6 +72,8 @@ func createUsersApiUrl(accountUrl string) (*url.URL, error) { } // createUserApiUrl creates the URL for a specific user REST API endpoint. +// The userName should be the actual identifier (case-sensitive if created with quotes). +// url.JoinPath will properly encode the path segment. func createUserApiUrl(accountUrl string, userName string) (*url.URL, error) { stringUrl, err := url.JoinPath(accountUrl, "api/v2/users", userName) if err != nil { @@ -81,7 +84,15 @@ func createUserApiUrl(accountUrl string, userName string) (*url.URL, error) { } // doRequest is a helper method that wraps HTTP request logic for REST API calls. -func (c *Client) doRequest(ctx context.Context, method string, endpoint *url.URL, target interface{}, body interface{}, opts ...uhttp.RequestOption) (*http.Header, *v2.RateLimitDescription, error) { +// Returns the response headers, rate limit description, status code, and error. +func (c *Client) doRequest( + ctx context.Context, + method string, + endpoint *url.URL, + target interface{}, + body interface{}, + opts ...uhttp.RequestOption, +) (*http.Header, *v2.RateLimitDescription, int, error) { var requestOptions []uhttp.RequestOption requestOptions = append(requestOptions, uhttp.WithAcceptJSONHeader(), @@ -96,7 +107,7 @@ func (c *Client) doRequest(ctx context.Context, method string, endpoint *url.URL request, err := c.NewRequest(ctx, method, endpoint, requestOptions...) if err != nil { - return nil, nil, fmt.Errorf("baton-snowflake: failed to create request: %w", err) + return nil, nil, 0, fmt.Errorf("baton-snowflake: failed to create request: %w", err) } var rateLimitData v2.RateLimitDescription @@ -111,7 +122,7 @@ func (c *Client) doRequest(ctx context.Context, method string, endpoint *url.URL response, err := c.Do(request, doOptions...) if err != nil { - return nil, &rateLimitData, fmt.Errorf("baton-snowflake: request failed: %w", err) + return nil, &rateLimitData, 0, fmt.Errorf("baton-snowflake: request failed: %w", err) } defer func() { if response == nil || response.Body == nil { @@ -124,42 +135,46 @@ func (c *Client) doRequest(ctx context.Context, method string, endpoint *url.URL } }() - if response.StatusCode >= 300 { + statusCode := response.StatusCode + if statusCode >= 300 { // Try to extract error message from response if errorResponse.Code != "" || errorResponse.ErrMsg != "" { - return &response.Header, &rateLimitData, fmt.Errorf("baton-snowflake: snowflake API error: %s - %s", errorResponse.Code, errorResponse.Message()) + return &response.Header, &rateLimitData, statusCode, fmt.Errorf("baton-snowflake: snowflake API error: %s - %s", errorResponse.Code, errorResponse.Message()) } - return &response.Header, &rateLimitData, fmt.Errorf("baton-snowflake: unexpected status code %d", response.StatusCode) + return &response.Header, &rateLimitData, statusCode, fmt.Errorf("baton-snowflake: unexpected status code %d", statusCode) } - return &response.Header, &rateLimitData, nil + return &response.Header, &rateLimitData, statusCode, nil } // CreateUserREST creates a new Snowflake user using the REST API. // POST /api/v2/users. -func (c *Client) CreateUserREST(ctx context.Context, req *CreateUserRequest) (*User, *v2.RateLimitDescription, error) { +// Returns completed=true if status code is 200, false if 202 (accepted but not completed). +func (c *Client) CreateUserREST(ctx context.Context, req *CreateUserRequest) (bool, *v2.RateLimitDescription, error) { l := ctxzap.Extract(ctx) usersApiUrl, err := createUsersApiUrl(c.AccountUrl) if err != nil { - return nil, nil, fmt.Errorf("baton-snowflake: failed to create users API URL: %w", err) + return false, nil, fmt.Errorf("baton-snowflake: failed to create users API URL: %w", err) } var response CreateUserResponse - _, rateLimitDesc, err := c.doRequest(ctx, http.MethodPost, usersApiUrl, &response, req, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) + _, rateLimitDesc, statusCode, err := c.doRequest(ctx, http.MethodPost, usersApiUrl, &response, req, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) if err != nil { - l.Error("baton-snowflake: failed to create user", - zap.String("user_name", req.Name), - zap.Error(err), - ) - return nil, rateLimitDesc, err + return false, rateLimitDesc, err } - l.Debug("baton-snowflake: user created successfully", - zap.String("user_name", response.User.Username), + completed := statusCode == http.StatusOK + + l.Debug("baton-snowflake: user creation request completed", + zap.String("user_name", req.Name), + zap.Int("status_code", statusCode), + zap.Bool("completed", completed), + zap.String("status", response.Status), + zap.String("message", response.Message), ) - return &response.User, rateLimitDesc, nil + return completed, rateLimitDesc, nil } // DeleteUserREST deletes a Snowflake user using the REST API. @@ -181,7 +196,7 @@ func (c *Client) DeleteUserREST(ctx context.Context, userName string, options *D userApiUrl.RawQuery = query.Encode() } - _, rateLimitDesc, err := c.doRequest(ctx, http.MethodDelete, userApiUrl, nil, nil, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) + _, rateLimitDesc, _, err := c.doRequest(ctx, http.MethodDelete, userApiUrl, nil, nil, uhttp.WithHeader(RoleHeaderKey, UserAdminRole)) if err != nil { l.Error("baton-snowflake: failed to delete user", zap.String("user_name", userName), From 65339bcd6b77273bfdb378b767e06c4291e3e05b Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Wed, 11 Feb 2026 18:17:01 -0300 Subject: [PATCH 05/12] avoid returning complete http response --- pkg/connector/account_roles.go | 8 ++--- pkg/connector/connector.go | 2 +- pkg/connector/databases.go | 2 +- pkg/connector/users.go | 12 ++++---- pkg/snowflake/account_role.go | 55 +++++++++++++++++----------------- pkg/snowflake/database.go | 27 ++++++++--------- pkg/snowflake/user.go | 29 ++++++++++-------- 7 files changed, 68 insertions(+), 67 deletions(-) diff --git a/pkg/connector/account_roles.go b/pkg/connector/account_roles.go index a5b6c264..61ce2865 100644 --- a/pkg/connector/account_roles.go +++ b/pkg/connector/account_roles.go @@ -46,7 +46,7 @@ func (o *accountRoleBuilder) List(ctx context.Context, parentResourceID *v2.Reso return nil, nil, wrapError(err, "failed to get next page offset") } - accountRoles, _, err := o.client.ListAccountRoles(ctx, cursor, resourcePageSize) + accountRoles, err := o.client.ListAccountRoles(ctx, cursor, resourcePageSize) if err != nil { return nil, nil, wrapError(err, "failed to list account roles") } @@ -95,7 +95,7 @@ func (o *accountRoleBuilder) Entitlements(_ context.Context, resource *v2.Resour } func (o *accountRoleBuilder) Grants(ctx context.Context, resource *v2.Resource, _ rs.SyncOpAttrs) ([]*v2.Grant, *rs.SyncOpResults, error) { - accountRoleGrantees, _, err := o.client.ListAccountRoleGrantees(ctx, resource.DisplayName) + accountRoleGrantees, err := o.client.ListAccountRoleGrantees(ctx, resource.DisplayName) if err != nil { return nil, nil, wrapError(err, "failed to list account role grantees") } @@ -138,7 +138,7 @@ func (o *accountRoleBuilder) Grant(ctx context.Context, principal *v2.Resource, return nil, err } - _, err := o.client.GrantAccountRole(ctx, entitlement.Resource.Id.Resource, principal.Id.Resource) + err := o.client.GrantAccountRole(ctx, entitlement.Resource.Id.Resource, principal.Id.Resource) if err != nil { err = wrapError(err, "failed to grant account role") @@ -167,7 +167,7 @@ func (o *accountRoleBuilder) Revoke(ctx context.Context, grant *v2.Grant) (annot return nil, err } - _, err := o.client.RevokeAccountRole(ctx, grant.Entitlement.Resource.Id.Resource, grant.Principal.Id.Resource) + err := o.client.RevokeAccountRole(ctx, grant.Entitlement.Resource.Id.Resource, grant.Principal.Id.Resource) if err != nil { err = wrapError(err, "failed to revoke account role") diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 169c98d9..b84f0ea0 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -187,7 +187,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) // Validate is called to ensure that the connector is properly configured. It should exercise any API credentials // to be sure that they are valid. func (d *Connector) Validate(ctx context.Context) (annotations.Annotations, error) { - users, _, err := d.Client.ListUsers(ctx, "", 1) + users, err := d.Client.ListUsers(ctx, "", 1) if err != nil { return nil, err } diff --git a/pkg/connector/databases.go b/pkg/connector/databases.go index b7f60920..e6b6bed4 100644 --- a/pkg/connector/databases.go +++ b/pkg/connector/databases.go @@ -61,7 +61,7 @@ func (o *databaseBuilder) List(ctx context.Context, parentResourceID *v2.Resourc return nil, nil, wrapError(err, "failed to get next page offset") } - databases, _, err := o.client.ListDatabases(ctx, cursor, resourcePageSize) + databases, err := o.client.ListDatabases(ctx, cursor, resourcePageSize) if err != nil { return nil, nil, wrapError(err, "failed to list databases") } diff --git a/pkg/connector/users.go b/pkg/connector/users.go index 3a446880..06fe798e 100644 --- a/pkg/connector/users.go +++ b/pkg/connector/users.go @@ -168,7 +168,7 @@ func (o *userBuilder) List(ctx context.Context, parentResourceID *v2.ResourceId, return nil, nil, wrapError(err, "failed to get next page cursor") } - users, _, err := o.client.ListUsers(ctx, cursor, resourcePageSize) + users, err := o.client.ListUsers(ctx, cursor, resourcePageSize) if err != nil { return nil, nil, wrapError(err, "failed to list users") } @@ -338,8 +338,8 @@ func (o *userBuilder) fetchUserWithSQLRetry(ctx context.Context, userName string baseDelay := 500 * time.Millisecond for attempt := 0; attempt <= maxRetries; attempt++ { - user, resp, err := o.client.GetUser(ctx, userName) - if err == nil && resp != nil && resp.StatusCode == http.StatusOK { + user, statusCode, err := o.client.GetUser(ctx, userName) + if err == nil && statusCode == http.StatusOK { l.Debug("user fetched successfully via SQL API", zap.String("user_name", userName), ) @@ -348,7 +348,7 @@ func (o *userBuilder) fetchUserWithSQLRetry(ctx context.Context, userName string // Check if we got a 422 error is422 := false - if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { + if statusCode == http.StatusUnprocessableEntity { is422 = true } else if err != nil { errStr := strings.ToLower(err.Error()) @@ -360,8 +360,8 @@ func (o *userBuilder) fetchUserWithSQLRetry(ctx context.Context, userName string if err != nil { return nil, err } - if resp != nil && resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("baton-snowflake: unexpected status code %d when fetching user", resp.StatusCode) + if statusCode != http.StatusOK { + return nil, fmt.Errorf("baton-snowflake: unexpected status code %d when fetching user", statusCode) } return nil, fmt.Errorf("baton-snowflake: failed to fetch user") } diff --git a/pkg/snowflake/account_role.go b/pkg/snowflake/account_role.go index 3cea4a0e..faca54f4 100644 --- a/pkg/snowflake/account_role.go +++ b/pkg/snowflake/account_role.go @@ -3,7 +3,6 @@ package snowflake import ( "context" "fmt" - "net/http" "github.com/conductorone/baton-sdk/pkg/uhttp" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" @@ -65,7 +64,7 @@ func (r *ListAccountRoleGranteesRawResponse) GetAccountRoleGrantees() []AccountR return accountRoleGrantees } -func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) ([]AccountRole, *http.Response, error) { +func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) ([]AccountRole, error) { var queries []string if cursor != "" { @@ -76,13 +75,13 @@ func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListAccountRolesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, nil, err + return nil, err } defer closeResponseBody(resp) @@ -91,117 +90,117 @@ func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, nil, err + return nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, resp, err + return nil, err } defer closeResponseBody(resp) accountRoles, err := response.GetAccountRoles() if err != nil { - return nil, resp, err + return nil, err } - return accountRoles, resp, nil + return accountRoles, nil } -func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) ([]AccountRoleGrantee, *http.Response, error) { +func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) ([]AccountRoleGrantee, error) { queries := []string{ fmt.Sprintf("SHOW GRANTS OF ROLE \"%s\";", roleName), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListAccountRoleGranteesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, nil, err + return nil, err } defer closeResponseBody(resp) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, nil, err + return nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, resp, err + return nil, err } defer closeResponseBody(resp) accountRoleGrantees := response.GetAccountRoleGrantees() - return accountRoleGrantees, resp, nil + return accountRoleGrantees, nil } -func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountRole, *http.Response, error) { +func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountRole, error) { queries := []string{ fmt.Sprintf("SHOW ROLES LIKE '%s' LIMIT 1;", roleName), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListAccountRolesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, nil, err + return nil, err } defer closeResponseBody(resp) accountRoles, err := response.GetAccountRoles() if err != nil { - return nil, resp, err + return nil, err } if len(accountRoles) == 0 { - return nil, resp, nil + return nil, nil } - return &accountRoles[0], resp, nil + return &accountRoles[0], nil } -func (c *Client) GrantAccountRole(ctx context.Context, roleName, userName string) (*http.Response, error) { +func (c *Client) GrantAccountRole(ctx context.Context, roleName, userName string) error { queries := []string{ fmt.Sprintf("GRANT ROLE \"%s\" TO USER \"%s\";", roleName, userName), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, err + return err } resp, err := c.Do(req) if err != nil { - return resp, err + return err } defer closeResponseBody(resp) - return resp, nil + return nil } -func (c *Client) RevokeAccountRole(ctx context.Context, roleName, userName string) (*http.Response, error) { +func (c *Client) RevokeAccountRole(ctx context.Context, roleName, userName string) error { queries := []string{ fmt.Sprintf("REVOKE ROLE \"%s\" FROM USER \"%s\";", roleName, userName), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, err + return err } resp, err := c.Do(req) if err != nil { - return resp, err + return err } defer closeResponseBody(resp) - return resp, nil + return nil } diff --git a/pkg/snowflake/database.go b/pkg/snowflake/database.go index de4c2abe..8b5a161c 100644 --- a/pkg/snowflake/database.go +++ b/pkg/snowflake/database.go @@ -3,7 +3,6 @@ package snowflake import ( "context" "fmt" - "net/http" "strings" "github.com/conductorone/baton-sdk/pkg/uhttp" @@ -60,7 +59,7 @@ func (r *ListDatabasesRawResponse) GetDatabases() ([]Database, error) { return databases, nil } -func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([]Database, *http.Response, error) { +func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([]Database, error) { var queries []string if cursor != "" { @@ -71,13 +70,13 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListDatabasesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, nil, err + return nil, err } defer closeResponseBody(resp) @@ -86,30 +85,30 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, nil, err + return nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, resp, err + return nil, err } defer closeResponseBody(resp) dbs, err := response.GetDatabases() if err != nil { - return nil, resp, err + return nil, err } - return dbs, resp, nil + return dbs, nil } -func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, *http.Response, error) { +func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, error) { queries := []string{ fmt.Sprintf("SHOW DATABASES LIKE '%s' LIMIT 1;", name), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListDatabasesRawResponse @@ -124,14 +123,14 @@ func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, *http databases, err := response.GetDatabases() if err != nil { - return nil, resp, err + return nil, err } if len(databases) == 0 { - return nil, resp, fmt.Errorf("database with name %s not found", name) + return nil, fmt.Errorf("database with name %s not found", name) } else if len(databases) > 1 { - return nil, resp, fmt.Errorf("expected 1 database with name %s, got %d", name, len(databases)) + return nil, fmt.Errorf("expected 1 database with name %s, got %d", name, len(databases)) } - return &databases[0], resp, nil + return &databases[0], nil } diff --git a/pkg/snowflake/user.go b/pkg/snowflake/user.go index eafa7c4a..3ad679e8 100644 --- a/pkg/snowflake/user.go +++ b/pkg/snowflake/user.go @@ -3,7 +3,6 @@ package snowflake import ( "context" "fmt" - "net/http" "reflect" "strings" "time" @@ -183,7 +182,7 @@ func (r *GetUserRawResponse) GetValueByColumnName(columnName string) (string, bo return "", false } -func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]User, *http.Response, error) { +func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]User, error) { var queries []string if cursor != "" { queries = append(queries, fmt.Sprintf("SHOW USERS LIMIT %d FROM '%s';", limit, cursor)) @@ -193,57 +192,61 @@ func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]Use req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListUsersRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, resp, err + return nil, err } defer closeResponseBody(resp) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, nil, err + return nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, resp, err + return nil, err } defer closeResponseBody(resp) users, err := response.GetUsers() if err != nil { - return nil, resp, err + return nil, err } - return users, resp, nil + return users, nil } -func (c *Client) GetUser(ctx context.Context, username string) (*User, *http.Response, error) { +func (c *Client) GetUser(ctx context.Context, username string) (*User, int, error) { queries := []string{ fmt.Sprintf("DESCRIBE USER \"%s\";", username), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, 0, err } var response GetUserRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) if err != nil { - return nil, resp, err + statusCode := 0 + if resp != nil { + statusCode = resp.StatusCode + } + return nil, statusCode, err } defer closeResponseBody(resp) user, err := response.GetUser() if err != nil { - return nil, resp, err + return nil, resp.StatusCode, err } - return user, resp, nil + return user, resp.StatusCode, nil } func (r *ListSecretsRawResponse) ListSecrets() ([]Secret, error) { From 2b30f572d4e5f875855a0c14402d5b001510906e Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Thu, 12 Feb 2026 10:26:57 -0300 Subject: [PATCH 06/12] ensure close body --- pkg/snowflake/account_role.go | 14 +++++++------- pkg/snowflake/database.go | 6 +++--- pkg/snowflake/secrets.go | 4 ++-- pkg/snowflake/user.go | 6 +++--- pkg/snowflake/user_rest.go | 6 +++--- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pkg/snowflake/account_role.go b/pkg/snowflake/account_role.go index faca54f4..93db859e 100644 --- a/pkg/snowflake/account_role.go +++ b/pkg/snowflake/account_role.go @@ -80,10 +80,10 @@ func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) var response ListAccountRolesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer closeResponseBody(resp) l := ctxzap.Extract(ctx) l.Debug("ListAccountRoles", zap.String("response.code", response.Code), zap.String("response.message", response.Message)) @@ -93,10 +93,10 @@ func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) return nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer closeResponseBody(resp) accountRoles, err := response.GetAccountRoles() if err != nil { @@ -118,20 +118,20 @@ func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) ( var response ListAccountRoleGranteesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer closeResponseBody(resp) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { return nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer closeResponseBody(resp) accountRoleGrantees := response.GetAccountRoleGrantees() @@ -150,10 +150,10 @@ func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountR var response ListAccountRolesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer closeResponseBody(resp) accountRoles, err := response.GetAccountRoles() if err != nil { @@ -178,10 +178,10 @@ func (c *Client) GrantAccountRole(ctx context.Context, roleName, userName string } resp, err := c.Do(req) + defer closeResponseBody(resp) if err != nil { return err } - defer closeResponseBody(resp) return nil } @@ -197,10 +197,10 @@ func (c *Client) RevokeAccountRole(ctx context.Context, roleName, userName strin } resp, err := c.Do(req) + defer closeResponseBody(resp) if err != nil { return err } - defer closeResponseBody(resp) return nil } diff --git a/pkg/snowflake/database.go b/pkg/snowflake/database.go index 8b5a161c..0daad95b 100644 --- a/pkg/snowflake/database.go +++ b/pkg/snowflake/database.go @@ -75,10 +75,10 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ var response ListDatabasesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer closeResponseBody(resp) l := ctxzap.Extract(ctx) l.Debug("ListDatabases", zap.String("response.code", response.Code), zap.String("response.message", response.Message)) @@ -88,10 +88,10 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ return nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer closeResponseBody(resp) dbs, err := response.GetDatabases() if err != nil { @@ -113,13 +113,13 @@ func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, error var response ListDatabasesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { if IsUnprocessableEntity(resp, err) { return nil, resp, nil } return nil, resp, err } - defer closeResponseBody(resp) databases, err := response.GetDatabases() if err != nil { diff --git a/pkg/snowflake/secrets.go b/pkg/snowflake/secrets.go index fd5aa9bc..68ea873e 100644 --- a/pkg/snowflake/secrets.go +++ b/pkg/snowflake/secrets.go @@ -29,6 +29,7 @@ func (c *Client) ListSecrets(ctx context.Context, database string) ([]Secret, er var response ListSecretsRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { var errMsg struct { @@ -55,7 +56,6 @@ func (c *Client) ListSecrets(ctx context.Context, database string) ([]Secret, er return nil, err } - defer resp.Body.Close() secrets, err := response.ListSecrets() if err != nil { @@ -77,10 +77,10 @@ func (c *Client) UserRsa(ctx context.Context, username string) (*UserRsa, error) var response RsaGetUserRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer resp.Body.Close() secrets, err := response.GetUserRsa(ctx) if err != nil { diff --git a/pkg/snowflake/user.go b/pkg/snowflake/user.go index 3ad679e8..0c6e1692 100644 --- a/pkg/snowflake/user.go +++ b/pkg/snowflake/user.go @@ -197,20 +197,20 @@ func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]Use var response ListUsersRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer closeResponseBody(resp) req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { return nil, err } resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { return nil, err } - defer closeResponseBody(resp) users, err := response.GetUsers() if err != nil { @@ -232,6 +232,7 @@ func (c *Client) GetUser(ctx context.Context, username string) (*User, int, erro var response GetUserRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp) if err != nil { statusCode := 0 if resp != nil { @@ -239,7 +240,6 @@ func (c *Client) GetUser(ctx context.Context, username string) (*User, int, erro } return nil, statusCode, err } - defer closeResponseBody(resp) user, err := response.GetUser() if err != nil { diff --git a/pkg/snowflake/user_rest.go b/pkg/snowflake/user_rest.go index 1f620799..146fb41d 100644 --- a/pkg/snowflake/user_rest.go +++ b/pkg/snowflake/user_rest.go @@ -121,9 +121,6 @@ func (c *Client) doRequest( } response, err := c.Do(request, doOptions...) - if err != nil { - return nil, &rateLimitData, 0, fmt.Errorf("baton-snowflake: request failed: %w", err) - } defer func() { if response == nil || response.Body == nil { return @@ -134,6 +131,9 @@ func (c *Client) doRequest( log.Printf("baton-snowflake: warning: failed to close response body: %v", closeErr) } }() + if err != nil { + return nil, &rateLimitData, 0, fmt.Errorf("baton-snowflake: request failed: %w", err) + } statusCode := response.StatusCode if statusCode >= 300 { From fa96a020535f4aa0342a4df5ee2ceb51187a582d Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Thu, 12 Feb 2026 10:34:31 -0300 Subject: [PATCH 07/12] use credential opts force change password --- pkg/connector/connector.go | 17 ++++------------- pkg/connector/users.go | 6 +----- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index b84f0ea0..4d23922d 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -130,21 +130,12 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) BoolField: &v2.ConnectorAccountCreationSchema_BoolField{}, }, }, - "must_change_password": { - DisplayName: "Must Change Password", - Required: false, - Description: "Whether the user must change their password on next login", - Order: 8, - Field: &v2.ConnectorAccountCreationSchema_Field_BoolField{ - BoolField: &v2.ConnectorAccountCreationSchema_BoolField{}, - }, - }, "default_warehouse": { DisplayName: "Default Warehouse", Required: false, Description: "The default warehouse to use when this user starts a session", Placeholder: "COMPUTE_WH", - Order: 9, + Order: 8, Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ StringField: &v2.ConnectorAccountCreationSchema_StringField{}, }, @@ -154,7 +145,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) Required: false, Description: "The default namespace to use when this user starts a session", Placeholder: "DATABASE.SCHEMA", - Order: 10, + Order: 9, Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ StringField: &v2.ConnectorAccountCreationSchema_StringField{}, }, @@ -164,7 +155,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) Required: false, Description: "The default role to use when this user starts a session", Placeholder: "PUBLIC", - Order: 11, + Order: 10, Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ StringField: &v2.ConnectorAccountCreationSchema_StringField{}, }, @@ -174,7 +165,7 @@ func (d *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) Required: false, Description: "The default secondary roles of this user to use when starting a session. Valid values: ALL or NONE. Default is ALL.", Placeholder: "ALL", - Order: 12, + Order: 11, Field: &v2.ConnectorAccountCreationSchema_Field_StringField{ StringField: &v2.ConnectorAccountCreationSchema_StringField{}, }, diff --git a/pkg/connector/users.go b/pkg/connector/users.go index 06fe798e..41e7cb8a 100644 --- a/pkg/connector/users.go +++ b/pkg/connector/users.go @@ -141,10 +141,6 @@ func extractProfileFields(accountInfo *v2.AccountInfo, createReq *snowflake.Crea if disabledVal, ok := pMap["disabled"].(bool); ok { createReq.Disabled = disabledVal } - // Allow must_change_password to be overridden from profile (as boolean) - if mustChangeVal, ok := pMap["must_change_password"].(bool); ok { - createReq.MustChangePassword = mustChangeVal - } // Default warehouse, namespace, role, and secondary roles if defaultWarehouseStr, ok := pMap["default_warehouse"].(string); ok && defaultWarehouseStr != "" { createReq.DefaultWarehouse = defaultWarehouseStr @@ -214,7 +210,6 @@ func (o *userBuilder) CreateAccountCapabilityDetails(ctx context.Context) (*v2.C return &v2.CredentialDetailsAccountProvisioning{ SupportedCredentialOptions: []v2.CapabilityDetailCredentialOption{ v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_RANDOM_PASSWORD, - v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_ENCRYPTED_PASSWORD, }, PreferredCredentialOption: v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_RANDOM_PASSWORD, }, nil, nil @@ -259,6 +254,7 @@ func (o *userBuilder) CreateAccount( // Handle password generation var plaintextData []*v2.PlaintextData if credentialOptions != nil { + createReq.MustChangePassword = credentialOptions.GetForceChangeAtNextLogin() // Generate password if random password is requested if randomPassword := credentialOptions.GetRandomPassword(); randomPassword != nil { password, err := crypto.GeneratePassword(ctx, credentialOptions) From 04b417dccbf051f84435068ad1c895738c763b3a Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Mon, 16 Feb 2026 15:03:03 -0300 Subject: [PATCH 08/12] address pr comments --- pkg/connector/users.go | 36 +++++++++++++++++------------------ pkg/snowflake/account_role.go | 16 ++++++++-------- pkg/snowflake/database.go | 8 ++++---- pkg/snowflake/user.go | 8 ++++---- pkg/snowflake/user_rest.go | 17 +++++++---------- 5 files changed, 41 insertions(+), 44 deletions(-) diff --git a/pkg/connector/users.go b/pkg/connector/users.go index 41e7cb8a..3126680c 100644 --- a/pkg/connector/users.go +++ b/pkg/connector/users.go @@ -2,6 +2,7 @@ package connector import ( "context" + "errors" "fmt" "net/http" "strings" @@ -210,6 +211,7 @@ func (o *userBuilder) CreateAccountCapabilityDetails(ctx context.Context) (*v2.C return &v2.CredentialDetailsAccountProvisioning{ SupportedCredentialOptions: []v2.CapabilityDetailCredentialOption{ v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_RANDOM_PASSWORD, + v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_ENCRYPTED_PASSWORD, }, PreferredCredentialOption: v2.CapabilityDetailCredentialOption_CAPABILITY_DETAIL_CREDENTIAL_OPTION_RANDOM_PASSWORD, }, nil, nil @@ -256,23 +258,19 @@ func (o *userBuilder) CreateAccount( if credentialOptions != nil { createReq.MustChangePassword = credentialOptions.GetForceChangeAtNextLogin() // Generate password if random password is requested - if randomPassword := credentialOptions.GetRandomPassword(); randomPassword != nil { - password, err := crypto.GeneratePassword(ctx, credentialOptions) - if err != nil { - return nil, nil, nil, wrapError(err, "failed to generate random password") - } - createReq.Password = password - - // Return the plaintext password so it can be encrypted and returned to the caller - plaintextData = append(plaintextData, v2.PlaintextData_builder{ - Name: "password", - Description: "Generated password for Snowflake user", - Bytes: []byte(password), - }.Build()) - } else if plaintextPassword := credentialOptions.GetPlaintextPassword(); plaintextPassword != nil { - // Use provided plaintext password - createReq.Password = plaintextPassword.GetPlaintextPassword() + if credentialOptions.GetRandomPassword() == nil && credentialOptions.GetPlaintextPassword() == nil { + return nil, nil, nil, errors.New("unsupported credential option") } + plaintextPassword, err := crypto.GeneratePassword(ctx, credentialOptions) + if err != nil { + return nil, nil, nil, wrapError(err, "failed to generate password") + } + createReq.Password = plaintextPassword + plaintextData = append(plaintextData, &v2.PlaintextData{ + Name: "password", + Description: "Password for the user", + Bytes: []byte(plaintextPassword), + }) } // Create user via REST API @@ -400,9 +398,11 @@ func (o *userBuilder) Delete(ctx context.Context, resourceId *v2.ResourceId, par // Quote the username to match the case-sensitive identifier created with quotes // This ensures we delete the exact case-sensitive identifier quotedUserName := fmt.Sprintf("\"%s\"", userName) + options := &snowflake.DeleteUserOptions{ + IfExists: true, + } // Delete user via REST API - // Using default options (ifExists=false) - _, err := o.client.DeleteUserREST(ctx, quotedUserName, nil) + _, err := o.client.DeleteUserREST(ctx, quotedUserName, options) if err != nil { l.Error("failed to delete user", zap.String("user_name", userName), diff --git a/pkg/snowflake/account_role.go b/pkg/snowflake/account_role.go index 93db859e..874e0efd 100644 --- a/pkg/snowflake/account_role.go +++ b/pkg/snowflake/account_role.go @@ -79,8 +79,8 @@ func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) } var response ListAccountRolesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) - defer closeResponseBody(resp) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { return nil, err } @@ -92,8 +92,8 @@ func (c *Client) ListAccountRoles(ctx context.Context, cursor string, limit int) if err != nil { return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) - defer closeResponseBody(resp) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { return nil, err } @@ -117,8 +117,8 @@ func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) ( } var response ListAccountRoleGranteesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) - defer closeResponseBody(resp) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { return nil, err } @@ -127,8 +127,8 @@ func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) ( if err != nil { return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) - defer closeResponseBody(resp) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { return nil, err } diff --git a/pkg/snowflake/database.go b/pkg/snowflake/database.go index 0daad95b..fdea1f87 100644 --- a/pkg/snowflake/database.go +++ b/pkg/snowflake/database.go @@ -74,8 +74,8 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ } var response ListDatabasesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) - defer closeResponseBody(resp) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { return nil, err } @@ -87,8 +87,8 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ if err != nil { return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) - defer closeResponseBody(resp) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { return nil, err } diff --git a/pkg/snowflake/user.go b/pkg/snowflake/user.go index 0c6e1692..7b20a836 100644 --- a/pkg/snowflake/user.go +++ b/pkg/snowflake/user.go @@ -196,8 +196,8 @@ func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]Use } var response ListUsersRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) - defer closeResponseBody(resp) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { return nil, err } @@ -206,8 +206,8 @@ func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]Use if err != nil { return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) - defer closeResponseBody(resp) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { return nil, err } diff --git a/pkg/snowflake/user_rest.go b/pkg/snowflake/user_rest.go index 146fb41d..bf972c8d 100644 --- a/pkg/snowflake/user_rest.go +++ b/pkg/snowflake/user_rest.go @@ -132,19 +132,16 @@ func (c *Client) doRequest( } }() if err != nil { - return nil, &rateLimitData, 0, fmt.Errorf("baton-snowflake: request failed: %w", err) - } - - statusCode := response.StatusCode - if statusCode >= 300 { - // Try to extract error message from response - if errorResponse.Code != "" || errorResponse.ErrMsg != "" { - return &response.Header, &rateLimitData, statusCode, fmt.Errorf("baton-snowflake: snowflake API error: %s - %s", errorResponse.Code, errorResponse.Message()) + statusCode := 0 + if response != nil { + statusCode = response.StatusCode } - return &response.Header, &rateLimitData, statusCode, fmt.Errorf("baton-snowflake: unexpected status code %d", statusCode) + return nil, &rateLimitData, statusCode, fmt.Errorf("baton-snowflake: request failed: %w", err) } - return &response.Header, &rateLimitData, statusCode, nil + // WithErrorResponse ensures c.Do() returns an error for status >= 300, + // so if we reach here, statusCode is guaranteed to be < 300 + return &response.Header, &rateLimitData, response.StatusCode, nil } // CreateUserREST creates a new Snowflake user using the REST API. From 7ff122d39569336fca7033dad45c6143a55c23cf Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Mon, 16 Feb 2026 15:31:43 -0300 Subject: [PATCH 09/12] refactor after rebase --- pkg/connector/databases.go | 4 +-- pkg/connector/tables.go | 16 ++++----- pkg/snowflake/account_role.go | 16 +++++---- pkg/snowflake/database.go | 19 +++++----- pkg/snowflake/helper.go | 4 +-- pkg/snowflake/table.go | 66 +++++++++++++++-------------------- 6 files changed, 61 insertions(+), 64 deletions(-) diff --git a/pkg/connector/databases.go b/pkg/connector/databases.go index e6b6bed4..ec5d2d6f 100644 --- a/pkg/connector/databases.go +++ b/pkg/connector/databases.go @@ -111,9 +111,9 @@ func (o *databaseBuilder) Grants(ctx context.Context, resource *v2.Resource, _ r return nil, nil, nil } - owner, ownerResp, err := o.client.GetAccountRole(ctx, database.Owner) + owner, ownerStatusCode, err := o.client.GetAccountRole(ctx, database.Owner) if err != nil { - if snowflake.IsUnprocessableEntity(ownerResp, err) { + if snowflake.IsUnprocessableEntity(ownerStatusCode, err) { wrappedErr := fmt.Errorf("baton-snowflake: insufficient privileges for database owner role %q (database %q): %w", database.Owner, resource.Id.Resource, err) return nil, nil, status.Error(codes.PermissionDenied, wrappedErr.Error()) } diff --git a/pkg/connector/tables.go b/pkg/connector/tables.go index 95333208..e1b39e5d 100644 --- a/pkg/connector/tables.go +++ b/pkg/connector/tables.go @@ -69,8 +69,8 @@ func (o *tableBuilder) isDBSharedOrSystem(ctx context.Context, resource *v2.Reso return val == "true" || val == "1", nil } } - db, resp, err := o.client.GetDatabase(ctx, databaseName) - if snowflake.IsUnprocessableEntity(resp, err) { + db, statusCode, err := o.client.GetDatabase(ctx, databaseName) + if snowflake.IsUnprocessableEntity(statusCode, err) { return true, nil } if err != nil { @@ -148,7 +148,7 @@ func (o *tableBuilder) List(ctx context.Context, parentResourceID *v2.ResourceId } const accountPageSize = 200 - tables, nextCursor, _, err := o.client.ListTablesInAccount(ctx, cursor, accountPageSize) + tables, nextCursor, err := o.client.ListTablesInAccount(ctx, cursor, accountPageSize) if err != nil { return nil, nil, wrapError(err, "failed to list tables in account") } @@ -280,9 +280,9 @@ func (o *tableBuilder) Grants(ctx context.Context, resource *v2.Resource, opts r switch tg.GrantedTo { case grantedToRole: - role, resp, err := o.client.GetAccountRole(ctx, tg.GranteeName) + role, statusCode, err := o.client.GetAccountRole(ctx, tg.GranteeName) if err != nil { - if snowflake.IsUnprocessableEntity(resp, err) { + if snowflake.IsUnprocessableEntity(statusCode, err) { principalId, idErr := rs.NewResourceID(accountRoleResourceType, tg.GranteeName) if idErr != nil { continue @@ -335,14 +335,14 @@ func (o *tableBuilder) Grants(ctx context.Context, resource *v2.Resource, opts r } if ownerPrincipalID == nil { - table, _, err := o.client.GetTable(ctx, databaseName, schemaName, tableName) + table, err := o.client.GetTable(ctx, databaseName, schemaName, tableName) if err != nil { return nil, nil, wrapError(err, "failed to get table for owner fallback") } if table != nil && table.Owner != "" && table.Owner != "SNOWFLAKE" { - owner, ownerResp, err := o.client.GetAccountRole(ctx, table.Owner) + owner, ownerStatusCode, err := o.client.GetAccountRole(ctx, table.Owner) switch { - case snowflake.IsUnprocessableEntity(ownerResp, err): + case snowflake.IsUnprocessableEntity(ownerStatusCode, err): // system role, skip case err != nil: return nil, nil, wrapError(err, fmt.Sprintf("failed to get account role for table owner %q", table.Owner)) diff --git a/pkg/snowflake/account_role.go b/pkg/snowflake/account_role.go index 874e0efd..4ded711f 100644 --- a/pkg/snowflake/account_role.go +++ b/pkg/snowflake/account_role.go @@ -138,33 +138,37 @@ func (c *Client) ListAccountRoleGrantees(ctx context.Context, roleName string) ( return accountRoleGrantees, nil } -func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountRole, error) { +func (c *Client) GetAccountRole(ctx context.Context, roleName string) (*AccountRole, int, error) { queries := []string{ fmt.Sprintf("SHOW ROLES LIKE '%s' LIMIT 1;", roleName), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, err + return nil, 0, err } var response ListAccountRolesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) defer closeResponseBody(resp) if err != nil { - return nil, err + statusCode := 0 + if resp != nil { + statusCode = resp.StatusCode + } + return nil, statusCode, err } accountRoles, err := response.GetAccountRoles() if err != nil { - return nil, err + return nil, resp.StatusCode, err } if len(accountRoles) == 0 { - return nil, nil + return nil, resp.StatusCode, nil } - return &accountRoles[0], nil + return &accountRoles[0], resp.StatusCode, nil } func (c *Client) GrantAccountRole(ctx context.Context, roleName, userName string) error { diff --git a/pkg/snowflake/database.go b/pkg/snowflake/database.go index fdea1f87..46356e55 100644 --- a/pkg/snowflake/database.go +++ b/pkg/snowflake/database.go @@ -101,36 +101,37 @@ func (c *Client) ListDatabases(ctx context.Context, cursor string, limit int) ([ return dbs, nil } -func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, error) { +func (c *Client) GetDatabase(ctx context.Context, name string) (*Database, int, error) { queries := []string{ fmt.Sprintf("SHOW DATABASES LIKE '%s' LIMIT 1;", name), } req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, err + return nil, 0, err } var response ListDatabasesRawResponse resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) defer closeResponseBody(resp) if err != nil { - if IsUnprocessableEntity(resp, err) { - return nil, resp, nil + statusCode := 0 + if resp != nil { + statusCode = resp.StatusCode } - return nil, resp, err + return nil, statusCode, err } databases, err := response.GetDatabases() if err != nil { - return nil, err + return nil, resp.StatusCode, err } if len(databases) == 0 { - return nil, fmt.Errorf("database with name %s not found", name) + return nil, resp.StatusCode, fmt.Errorf("database with name %s not found", name) } else if len(databases) > 1 { - return nil, fmt.Errorf("expected 1 database with name %s, got %d", name, len(databases)) + return nil, resp.StatusCode, fmt.Errorf("expected 1 database with name %s, got %d", name, len(databases)) } - return &databases[0], nil + return &databases[0], resp.StatusCode, nil } diff --git a/pkg/snowflake/helper.go b/pkg/snowflake/helper.go index 34f95339..5aa2253b 100644 --- a/pkg/snowflake/helper.go +++ b/pkg/snowflake/helper.go @@ -8,8 +8,8 @@ import ( // IsUnprocessableEntity reports whether the Snowflake API returned HTTP 422 (Unprocessable Entity). // Snowflake returns 422 for certain operations on system/predefined objects (e.g. SHOW GRANTS OF ROLE for ACCOUNTADMIN, // SHOW ROLES LIKE for some roles). Callers can treat this as "no data" or "not resolvable" instead of a hard error. -func IsUnprocessableEntity(resp *http.Response, err error) bool { - if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { +func IsUnprocessableEntity(statusCode int, err error) bool { + if statusCode == http.StatusUnprocessableEntity { return true } if err != nil && (strings.Contains(err.Error(), "422") || strings.Contains(err.Error(), "Unprocessable Entity")) { diff --git a/pkg/snowflake/table.go b/pkg/snowflake/table.go index 7bef8f19..ebeba9b0 100644 --- a/pkg/snowflake/table.go +++ b/pkg/snowflake/table.go @@ -61,7 +61,7 @@ func (r *ListTablesRawResponse) ListTables() ([]Table, error) { const tableListCursorSep = "\x00" -func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit int) ([]Table, string, *http.Response, error) { +func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit int) ([]Table, string, error) { l := ctxzap.Extract(ctx) var q string @@ -81,43 +81,39 @@ func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit i req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, "", nil, err + return nil, "", err } var response ListTablesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { - if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { + if resp1 != nil && resp1.StatusCode == http.StatusUnprocessableEntity { l.Debug("Insufficient privileges for SHOW TABLES IN ACCOUNT") wrappedErr := fmt.Errorf("baton-snowflake: insufficient privileges for SHOW TABLES IN ACCOUNT: %w", err) - return nil, "", nil, status.Error(codes.PermissionDenied, wrappedErr.Error()) + return nil, "", status.Error(codes.PermissionDenied, wrappedErr.Error()) } - return nil, "", nil, err - } - if resp != nil { - defer resp.Body.Close() + return nil, "", err } req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, "", resp, err + return nil, "", err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { - if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { + if resp2 != nil && resp2.StatusCode == http.StatusUnprocessableEntity { l.Debug("Insufficient privileges for SHOW TABLES IN ACCOUNT (statement result)") wrappedErr := fmt.Errorf("baton-snowflake: insufficient privileges for SHOW TABLES IN ACCOUNT (statement result): %w", err) - return nil, "", nil, status.Error(codes.PermissionDenied, wrappedErr.Error()) + return nil, "", status.Error(codes.PermissionDenied, wrappedErr.Error()) } - return nil, "", resp, err - } - if resp != nil { - defer resp.Body.Close() + return nil, "", err } tables, err := response.ListTables() if err != nil { - return nil, "", resp, err + return nil, "", err } var nextCursor string @@ -125,7 +121,7 @@ func (c *Client) ListTablesInAccount(ctx context.Context, cursor string, limit i last := tables[len(tables)-1] nextCursor = last.DatabaseName + tableListCursorSep + last.SchemaName + tableListCursorSep + last.Name } - return tables, nextCursor, resp, nil + return tables, nextCursor, nil } // escapeSingleQuote doubles single quotes for use inside SQL string literals. @@ -149,7 +145,7 @@ func escapeDoubleQuotedIdentifier(s string) string { return strings.ReplaceAll(s, `"`, `""`) } -func (c *Client) GetTable(ctx context.Context, database, schema, tableName string) (*Table, *http.Response, error) { +func (c *Client) GetTable(ctx context.Context, database, schema, tableName string) (*Table, error) { likePattern := escapeLikePattern(tableName) queries := []string{ fmt.Sprintf("SHOW TABLES LIKE '%s' ESCAPE '\\' IN SCHEMA \"%s\".\"%s\" LIMIT 1;", likePattern, escapeDoubleQuotedIdentifier(database), escapeDoubleQuotedIdentifier(schema)), @@ -157,46 +153,42 @@ func (c *Client) GetTable(ctx context.Context, database, schema, tableName strin req, err := c.PostStatementRequest(ctx, queries) if err != nil { - return nil, nil, err + return nil, err } var response ListTablesRawResponse - resp, err := c.Do(req, uhttp.WithJSONResponse(&response)) + resp1, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp1) if err != nil { - if resp != nil && resp.StatusCode == http.StatusUnprocessableEntity { - return nil, resp, nil + if resp1 != nil && resp1.StatusCode == http.StatusUnprocessableEntity { + return nil, nil } - return nil, nil, err - } - if resp != nil { - defer resp.Body.Close() + return nil, err } req, err = c.GetStatementResponse(ctx, response.StatementHandle) if err != nil { - return nil, resp, err + return nil, err } - resp, err = c.Do(req, uhttp.WithJSONResponse(&response)) + resp2, err := c.Do(req, uhttp.WithJSONResponse(&response)) + defer closeResponseBody(resp2) if err != nil { - return nil, resp, err - } - if resp != nil { - defer resp.Body.Close() + return nil, err } tables, err := response.ListTables() if err != nil { - return nil, resp, err + return nil, err } // Filter by exact match (database, schema, and name) for _, table := range tables { if table.DatabaseName == database && table.SchemaName == schema && table.Name == tableName { - return &table, resp, nil + return &table, nil } } - return nil, resp, fmt.Errorf("table %s.%s.%s not found", database, schema, tableName) + return nil, fmt.Errorf("table %s.%s.%s not found", database, schema, tableName) } var tableGrantStructFieldToColumnMap = map[string]string{ From e710dfa8a55e596193c420c9998f9f588f52404a Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Tue, 17 Feb 2026 14:23:53 -0300 Subject: [PATCH 10/12] change log to Debug --- pkg/snowflake/user_rest.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/snowflake/user_rest.go b/pkg/snowflake/user_rest.go index bf972c8d..b0fbabf1 100644 --- a/pkg/snowflake/user_rest.go +++ b/pkg/snowflake/user_rest.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "log" "net/http" "net/url" @@ -93,6 +92,7 @@ func (c *Client) doRequest( body interface{}, opts ...uhttp.RequestOption, ) (*http.Header, *v2.RateLimitDescription, int, error) { + l := ctxzap.Extract(ctx) var requestOptions []uhttp.RequestOption requestOptions = append(requestOptions, uhttp.WithAcceptJSONHeader(), @@ -128,7 +128,7 @@ func (c *Client) doRequest( _, _ = io.Copy(io.Discard, response.Body) closeErr := response.Body.Close() if closeErr != nil { - log.Printf("baton-snowflake: warning: failed to close response body: %v", closeErr) + l.Debug("baton-snowflake: warning: failed to close response body", zap.Error(closeErr)) } }() if err != nil { From 5a0394df7439ae4867494d13b931270f8f2aa72c Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Tue, 17 Feb 2026 14:46:20 -0300 Subject: [PATCH 11/12] fix lint --- pkg/snowflake/user_rest.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/snowflake/user_rest.go b/pkg/snowflake/user_rest.go index b0fbabf1..b38fd25f 100644 --- a/pkg/snowflake/user_rest.go +++ b/pkg/snowflake/user_rest.go @@ -22,7 +22,7 @@ type CreateUserRequest struct { LastName string `json:"lastName,omitempty"` Email string `json:"email,omitempty"` Comment string `json:"comment,omitempty"` - Password string `json:"password,omitempty"` + Password string `json:"password,omitempty"` // #nosec G117: used for Snowflake API request body MustChangePassword bool `json:"mustChangePassword,omitempty"` Disabled bool `json:"disabled,omitempty"` DefaultWarehouse string `json:"defaultWarehouse,omitempty"` From cc2a8505913328648623e28f34807e5bcb9fa2b3 Mon Sep 17 00:00:00 2001 From: agustin-conductor Date: Wed, 18 Feb 2026 14:40:27 -0300 Subject: [PATCH 12/12] address quotation mark in username edge case --- pkg/connector/helpers.go | 15 ++++++++++++++- pkg/connector/users.go | 14 ++++++-------- pkg/snowflake/user.go | 4 +++- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/pkg/connector/helpers.go b/pkg/connector/helpers.go index 77f2b4d2..52f5fa50 100644 --- a/pkg/connector/helpers.go +++ b/pkg/connector/helpers.go @@ -1,9 +1,22 @@ package connector -import "fmt" +import ( + "fmt" + "strings" +) func wrapError(err error, message string) error { return fmt.Errorf("snowflake-connector: %s: %w", message, err) } const resourcePageSize = 50 + +// quoteSnowflakeIdentifier properly escapes and quotes a Snowflake identifier. +// In Snowflake, double quotes inside identifiers must be escaped by doubling them. +// Example: o"donnel becomes "o""donnel". +func quoteSnowflakeIdentifier(identifier string) string { + // Escape double quotes by doubling them + escaped := strings.ReplaceAll(identifier, `"`, `""`) + // Wrap in double quotes + return fmt.Sprintf(`"%s"`, escaped) +} diff --git a/pkg/connector/users.go b/pkg/connector/users.go index 3126680c..efadd082 100644 --- a/pkg/connector/users.go +++ b/pkg/connector/users.go @@ -226,26 +226,23 @@ func (o *userBuilder) CreateAccount( l := ctxzap.Extract(ctx) // Extract user name from accountInfo - // The user name should come from profile.name first (required in schema), then fall back to Login + // The user name must be provided in profile.name (required in schema) userName := "" if profile := accountInfo.GetProfile(); profile != nil { if nameStr, ok := rs.GetProfileStringValue(profile, "name"); ok && nameStr != "" { userName = nameStr } } - // Fall back to login if profile name is not available - if userName == "" { - userName = accountInfo.GetLogin() - } if userName == "" { - return nil, nil, nil, status.Error(codes.InvalidArgument, "baton-snowflake: user name is required (provide via profile.name or login)") + return nil, nil, nil, status.Error(codes.InvalidArgument, "baton-snowflake: user name is required (provide via profile.name)") } // Build create user request // name is the only required field for the create user request // Quote the username to preserve case sensitivity (Snowflake stores unquoted identifiers in uppercase) - quotedUserName := fmt.Sprintf("\"%s\"", userName) + // Escape any double quotes in the username by doubling them + quotedUserName := quoteSnowflakeIdentifier(userName) createReq := &snowflake.CreateUserRequest{ Name: quotedUserName, } @@ -397,7 +394,8 @@ func (o *userBuilder) Delete(ctx context.Context, resourceId *v2.ResourceId, par // Quote the username to match the case-sensitive identifier created with quotes // This ensures we delete the exact case-sensitive identifier - quotedUserName := fmt.Sprintf("\"%s\"", userName) + // Escape any double quotes in the username by doubling them + quotedUserName := quoteSnowflakeIdentifier(userName) options := &snowflake.DeleteUserOptions{ IfExists: true, } diff --git a/pkg/snowflake/user.go b/pkg/snowflake/user.go index 7b20a836..adca85ed 100644 --- a/pkg/snowflake/user.go +++ b/pkg/snowflake/user.go @@ -221,8 +221,10 @@ func (c *Client) ListUsers(ctx context.Context, cursor string, limit int) ([]Use } func (c *Client) GetUser(ctx context.Context, username string) (*User, int, error) { + // Escape double quotes in username by doubling them before quoting + escapedUsername := escapeDoubleQuotedIdentifier(username) queries := []string{ - fmt.Sprintf("DESCRIBE USER \"%s\";", username), + fmt.Sprintf("DESCRIBE USER \"%s\";", escapedUsername), } req, err := c.PostStatementRequest(ctx, queries)