Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/postgres-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ resource_types:
queries:
# Update the user with generated password (PostgreSQL style)
- |
UPDATE users SET password_hash = crypt(?<password>, gen_salt('bf')) WHERE username = ?<resource_id>
UPDATE users SET password_hash = crypt(?<password>, gen_salt('bf')) WHERE username = ?<username>

# Configuration for "role" resources
role:
Expand Down
4 changes: 4 additions & 0 deletions pkg/bsql/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import (
v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
)

type staticValidator interface {
staticValidate(ctx context.Context, s *SQLSyncer) error
}

// Config represents the overall connector configuration.
type Config struct {
// AppName is the application name that identifies the connector.
Expand Down
24 changes: 22 additions & 2 deletions pkg/bsql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ type executor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

var identSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_]+`)

func SanitizeIdentifier(s string) string {
return identSanitizer.ReplaceAllString(s, "")
}

type paginationContext struct {
Strategy string
Limit int64
Expand Down Expand Up @@ -95,6 +101,20 @@ func parseToken(token string) (*queryTokenOpts, error) {
return opts, nil
}

func (s *SQLSyncer) queryVars(query string) ([]string, error) {
result := make([]string, 0)

for _, token := range queryOptRegex.FindAllString(query, -1) {
opts, err := parseToken(token)
if err != nil {
return nil, err
}
result = append(result, opts.Key)
}

return result, nil
}

func (s *SQLSyncer) parseQueryOpts(pCtx *paginationContext, query string, vars map[string]any) (string, []interface{}, bool, error) {
if vars == nil {
vars = make(map[string]any)
Expand Down Expand Up @@ -135,7 +155,7 @@ func (s *SQLSyncer) parseQueryOpts(pCtx *paginationContext, query string, vars m

// If the value is unquoted, directly insert the value as a string
if opts.Unquoted {
return fmt.Sprintf("%v", val)
return SanitizeIdentifier(fmt.Sprintf("%v", val))
}

qArgs = append(qArgs, val)
Expand Down Expand Up @@ -280,7 +300,7 @@ func (s *SQLSyncer) prepareProvisioningQuery(query string, vars map[string]any)
}

if opts.Unquoted {
return fmt.Sprintf("%v", v)
return SanitizeIdentifier(fmt.Sprintf("%v", v))
}

qArgs = append(qArgs, v)
Expand Down
17 changes: 17 additions & 0 deletions pkg/bsql/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,23 @@ func Test_parseQueryOpts(t *testing.T) {
false,
false,
},
{
"Test sql injection attempt with unquoted table name var substitution",
database.MySQL,
args{
t.Context(),
"SELECT * FROM ?<table_name|unquoted> WHERE test = ?<foo>",
nil,
map[string]any{
"table_name": `example_table; DROP TABLE users; --`,
"foo": "test example",
},
},
"SELECT * FROM example_tableDROPTABLEusers WHERE test = ?",
[]interface{}{"test example"},
false,
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
73 changes: 73 additions & 0 deletions pkg/bsql/sql_syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package bsql
import (
"context"
"database/sql"
"fmt"

v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
"github.com/conductorone/baton-sdk/pkg/connectorbuilder"
Expand Down Expand Up @@ -68,3 +69,75 @@ func NewActionSyncer(ctx context.Context, db *sql.DB, dbEngine database.DbEngine
fullConfig: fullConfig,
}, nil
}

func (s *SQLSyncer) validateInternal(ctx context.Context, validator staticValidator) error {
if validator == nil {
return nil
}

err := validator.staticValidate(ctx, s)
if err != nil {
return err
}

return nil
}

func (s *SQLSyncer) validateFormatErr(field string, err error) error {
if s.resourceType == nil {
return fmt.Errorf("validation error for action config, field %q: %w", field, err)
}

return fmt.Errorf("validation error for resource type %q, field %q: %w", s.resourceType.Id, field, err)
}

func (s *SQLSyncer) Validate(ctx context.Context) error {
if s.fullConfig.Actions != nil {
for key, action := range s.fullConfig.Actions {
err := s.validateInternal(ctx, &action)
if err != nil {
return s.validateFormatErr(fmt.Sprintf("Action[%s]", key), err)
}
}
}

if err := s.validateInternal(ctx, s.config.List); err != nil {
return s.validateFormatErr("list", err)
}

if s.config.Entitlements != nil {
if err := s.validateInternal(ctx, s.config.Entitlements); err != nil {
return s.validateFormatErr("entitlements", err)
}
}

if s.config.StaticEntitlements != nil {
for _, entitlement := range s.config.StaticEntitlements {
if err := s.validateInternal(ctx, entitlement); err != nil {
return s.validateFormatErr("static_entitlements", err)
}
}
}

if s.config.Grants != nil {
for _, grant := range s.config.Grants {
if err := s.validateInternal(ctx, grant); err != nil {
return s.validateFormatErr("grants", err)
}
}
}

if s.config.AccountProvisioning != nil {
if err := s.validateInternal(ctx, s.config.AccountProvisioning); err != nil {
return s.validateFormatErr("account_provisioning", err)
}
}

if s.config.CredentialRotation != nil {
if err := s.validateInternal(ctx, s.config.CredentialRotation); err != nil {
return s.validateFormatErr("credential_rotation", err)
}
}

return nil
}
169 changes: 169 additions & 0 deletions pkg/bsql/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package bsql

import (
"context"
"errors"
"fmt"
)

func validateVarsInQuery(s *SQLSyncer, query string, vars map[string]string) error {
if query == "" {
return fmt.Errorf("query is required")
}

usedVars, err := s.queryVars(query)
if err != nil {
return fmt.Errorf("failed to parse query for vars: %w", err)
}

if vars == nil {
vars = make(map[string]string)
}

for _, v := range usedVars {
if _, ok := vars[v]; !ok {
if v == limitKey || v == offsetKey || v == cursorKey {
continue
}
return fmt.Errorf("query uses variable '%s' which is not defined in vars", v)
}
}

return nil
}

func (l *ListQuery) staticValidate(ctx context.Context, s *SQLSyncer) error {
return validateVarsInQuery(s, l.Query, l.Vars)
}

func (l *EntitlementsQuery) staticValidate(ctx context.Context, s *SQLSyncer) error {
for _, mapping := range l.Map {
if mapping.Provisioning == nil {
continue
}

if mapping.Provisioning.Grant != nil {
for _, query := range mapping.Provisioning.Grant.Queries {
err := validateVarsInQuery(s, query, mapping.Provisioning.Vars)
if err != nil {
return err
}
}
}

if mapping.Provisioning.Revoke != nil {
for _, query := range mapping.Provisioning.Revoke.Queries {
err := validateVarsInQuery(s, query, mapping.Provisioning.Vars)
if err != nil {
return err
}
}
}
}

return validateVarsInQuery(s, l.Query, l.Vars)
Copy link

Copilot AI Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing validation of EntitlementsQuery.Map field. The EntitlementsQuery.StaticValidate method only validates the query variables but doesn't validate the Map field, which is a slice of EntitlementMapping objects. Each EntitlementMapping has provisioning queries that should also be validated.

Consider adding validation for the Map field:

func (l *EntitlementsQuery) StaticValidate(ctx context.Context, s *SQLSyncer) error {
	if err := validateVarsInQuery(s, l.Query, l.Vars); err != nil {
		return err
	}
	
	for i, mapping := range l.Map {
		if mapping != nil {
			if err := mapping.StaticValidate(ctx, s); err != nil {
				return fmt.Errorf("map[%d]: %w", i, err)
			}
		}
	}
	
	return nil
}
Suggested change
return validateVarsInQuery(s, l.Query, l.Vars)
if err := validateVarsInQuery(s, l.Query, l.Vars); err != nil {
return err
}
for i, mapping := range l.Map {
if mapping != nil {
if err := mapping.StaticValidate(ctx, s); err != nil {
return fmt.Errorf("map[%d]: %w", i, err)
}
}
}
return nil

Copilot uses AI. Check for mistakes.
}

func (l *EntitlementMapping) staticValidate(ctx context.Context, s *SQLSyncer) error {
if l.Provisioning == nil {
return nil
}

if l.Provisioning.Grant != nil {
for _, query := range l.Provisioning.Grant.Queries {
err := validateVarsInQuery(s, query, l.Provisioning.Vars)
if err != nil {
return err
}
}
}

if l.Provisioning.Revoke != nil {
for _, query := range l.Provisioning.Revoke.Queries {
err := validateVarsInQuery(s, query, l.Provisioning.Vars)
if err != nil {
return err
}
}
}

return nil
}

func (l *GrantsQuery) staticValidate(ctx context.Context, s *SQLSyncer) error {
return validateVarsInQuery(s, l.Query, l.Vars)
}

func (l *AccountProvisioning) staticValidate(ctx context.Context, s *SQLSyncer) error {
if l.Credentials == nil {
return errors.New("no credentials defined")
}

if l.Credentials.EncryptedPassword == nil &&
l.Credentials.RandomPassword == nil &&
l.Credentials.NoPassword == nil {
return errors.New("no credential method defined")
}

if l.Credentials.RandomPassword != nil {
if l.Credentials.RandomPassword.MaxLength <= 0 {
return errors.New("random password max_length must be greater than zero")
}

if l.Credentials.RandomPassword.MinLength <= 0 {
return errors.New("random password min_length must be greater than zero")
}

if l.Credentials.RandomPassword.MinLength > l.Credentials.RandomPassword.MaxLength {
return errors.New("random password min_length cannot be greater than max_length")
}
}

if l.Create == nil {
return errors.New("no create functions defined")
}

for _, query := range l.Create.Queries {
err := validateVarsInQuery(s, query, l.Create.Vars)
if err != nil {
return err
}
}

if l.Validate == nil {
return errors.New("no validate functions defined")
}

err := validateVarsInQuery(s, l.Validate.Query, l.Validate.Vars)
if err != nil {
return err
}

return nil
}

func (l *CredentialRotation) staticValidate(ctx context.Context, s *SQLSyncer) error {
if l.Update != nil {
for _, query := range l.Update.Queries {
err := validateVarsInQuery(s, query, l.Update.Vars)
if err != nil {
return err
}
}
}

return nil
}

func (l *ActionConfig) staticValidate(ctx context.Context, s *SQLSyncer) error {
availableVars := make(map[string]string)
for k, v := range l.Vars {
availableVars[k] = v
}

for k, config := range l.Arguments {
availableVars[k] = config.Type
}

return validateVarsInQuery(s, l.Query, availableVars)
}
Loading