diff --git a/examples/postgres-test.yml b/examples/postgres-test.yml index d9cd236f..1ea60bfc 100644 --- a/examples/postgres-test.yml +++ b/examples/postgres-test.yml @@ -187,7 +187,7 @@ resource_types: queries: # Update the user with generated password (PostgreSQL style) - | - UPDATE users SET password_hash = crypt(?, gen_salt('bf')) WHERE username = ? + UPDATE users SET password_hash = crypt(?, gen_salt('bf')) WHERE username = ? # Configuration for "role" resources role: diff --git a/pkg/bsql/config.go b/pkg/bsql/config.go index b580b77d..b173f5ac 100644 --- a/pkg/bsql/config.go +++ b/pkg/bsql/config.go @@ -13,6 +13,10 @@ import ( v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" ) +type staticValidator interface { + staticValidate(ctx context.Context, s *SQLSyncer) error +} + // Config represents the overall connector configuration. type Config struct { // AppName is the application name that identifies the connector. diff --git a/pkg/bsql/query.go b/pkg/bsql/query.go index 4a6358b4..66ccf7f0 100644 --- a/pkg/bsql/query.go +++ b/pkg/bsql/query.go @@ -30,6 +30,12 @@ type executor interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) } +var identSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_]+`) + +func SanitizeIdentifier(s string) string { + return identSanitizer.ReplaceAllString(s, "") +} + type paginationContext struct { Strategy string Limit int64 @@ -95,6 +101,20 @@ func parseToken(token string) (*queryTokenOpts, error) { return opts, nil } +func (s *SQLSyncer) queryVars(query string) ([]string, error) { + result := make([]string, 0) + + for _, token := range queryOptRegex.FindAllString(query, -1) { + opts, err := parseToken(token) + if err != nil { + return nil, err + } + result = append(result, opts.Key) + } + + return result, nil +} + func (s *SQLSyncer) parseQueryOpts(pCtx *paginationContext, query string, vars map[string]any) (string, []interface{}, bool, error) { if vars == nil { vars = make(map[string]any) @@ -135,7 +155,7 @@ func (s *SQLSyncer) parseQueryOpts(pCtx *paginationContext, query string, vars m // If the value is unquoted, directly insert the value as a string if opts.Unquoted { - return fmt.Sprintf("%v", val) + return SanitizeIdentifier(fmt.Sprintf("%v", val)) } qArgs = append(qArgs, val) @@ -280,7 +300,7 @@ func (s *SQLSyncer) prepareProvisioningQuery(query string, vars map[string]any) } if opts.Unquoted { - return fmt.Sprintf("%v", v) + return SanitizeIdentifier(fmt.Sprintf("%v", v)) } qArgs = append(qArgs, v) diff --git a/pkg/bsql/query_test.go b/pkg/bsql/query_test.go index 4fcae4a7..6e2a3bd2 100644 --- a/pkg/bsql/query_test.go +++ b/pkg/bsql/query_test.go @@ -358,6 +358,23 @@ func Test_parseQueryOpts(t *testing.T) { false, false, }, + { + "Test sql injection attempt with unquoted table name var substitution", + database.MySQL, + args{ + t.Context(), + "SELECT * FROM ? WHERE test = ?", + nil, + map[string]any{ + "table_name": `example_table; DROP TABLE users; --`, + "foo": "test example", + }, + }, + "SELECT * FROM example_tableDROPTABLEusers WHERE test = ?", + []interface{}{"test example"}, + false, + false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/bsql/sql_syncer.go b/pkg/bsql/sql_syncer.go index e907243e..dae8011f 100644 --- a/pkg/bsql/sql_syncer.go +++ b/pkg/bsql/sql_syncer.go @@ -3,6 +3,7 @@ package bsql import ( "context" "database/sql" + "fmt" v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" "github.com/conductorone/baton-sdk/pkg/connectorbuilder" @@ -68,3 +69,75 @@ func NewActionSyncer(ctx context.Context, db *sql.DB, dbEngine database.DbEngine fullConfig: fullConfig, }, nil } + +func (s *SQLSyncer) validateInternal(ctx context.Context, validator staticValidator) error { + if validator == nil { + return nil + } + + err := validator.staticValidate(ctx, s) + if err != nil { + return err + } + + return nil +} + +func (s *SQLSyncer) validateFormatErr(field string, err error) error { + if s.resourceType == nil { + return fmt.Errorf("validation error for action config, field %q: %w", field, err) + } + + return fmt.Errorf("validation error for resource type %q, field %q: %w", s.resourceType.Id, field, err) +} + +func (s *SQLSyncer) Validate(ctx context.Context) error { + if s.fullConfig.Actions != nil { + for key, action := range s.fullConfig.Actions { + err := s.validateInternal(ctx, &action) + if err != nil { + return s.validateFormatErr(fmt.Sprintf("Action[%s]", key), err) + } + } + } + + if err := s.validateInternal(ctx, s.config.List); err != nil { + return s.validateFormatErr("list", err) + } + + if s.config.Entitlements != nil { + if err := s.validateInternal(ctx, s.config.Entitlements); err != nil { + return s.validateFormatErr("entitlements", err) + } + } + + if s.config.StaticEntitlements != nil { + for _, entitlement := range s.config.StaticEntitlements { + if err := s.validateInternal(ctx, entitlement); err != nil { + return s.validateFormatErr("static_entitlements", err) + } + } + } + + if s.config.Grants != nil { + for _, grant := range s.config.Grants { + if err := s.validateInternal(ctx, grant); err != nil { + return s.validateFormatErr("grants", err) + } + } + } + + if s.config.AccountProvisioning != nil { + if err := s.validateInternal(ctx, s.config.AccountProvisioning); err != nil { + return s.validateFormatErr("account_provisioning", err) + } + } + + if s.config.CredentialRotation != nil { + if err := s.validateInternal(ctx, s.config.CredentialRotation); err != nil { + return s.validateFormatErr("credential_rotation", err) + } + } + + return nil +} diff --git a/pkg/bsql/validate.go b/pkg/bsql/validate.go new file mode 100644 index 00000000..42572181 --- /dev/null +++ b/pkg/bsql/validate.go @@ -0,0 +1,169 @@ +package bsql + +import ( + "context" + "errors" + "fmt" +) + +func validateVarsInQuery(s *SQLSyncer, query string, vars map[string]string) error { + if query == "" { + return fmt.Errorf("query is required") + } + + usedVars, err := s.queryVars(query) + if err != nil { + return fmt.Errorf("failed to parse query for vars: %w", err) + } + + if vars == nil { + vars = make(map[string]string) + } + + for _, v := range usedVars { + if _, ok := vars[v]; !ok { + if v == limitKey || v == offsetKey || v == cursorKey { + continue + } + return fmt.Errorf("query uses variable '%s' which is not defined in vars", v) + } + } + + return nil +} + +func (l *ListQuery) staticValidate(ctx context.Context, s *SQLSyncer) error { + return validateVarsInQuery(s, l.Query, l.Vars) +} + +func (l *EntitlementsQuery) staticValidate(ctx context.Context, s *SQLSyncer) error { + for _, mapping := range l.Map { + if mapping.Provisioning == nil { + continue + } + + if mapping.Provisioning.Grant != nil { + for _, query := range mapping.Provisioning.Grant.Queries { + err := validateVarsInQuery(s, query, mapping.Provisioning.Vars) + if err != nil { + return err + } + } + } + + if mapping.Provisioning.Revoke != nil { + for _, query := range mapping.Provisioning.Revoke.Queries { + err := validateVarsInQuery(s, query, mapping.Provisioning.Vars) + if err != nil { + return err + } + } + } + } + + return validateVarsInQuery(s, l.Query, l.Vars) +} + +func (l *EntitlementMapping) staticValidate(ctx context.Context, s *SQLSyncer) error { + if l.Provisioning == nil { + return nil + } + + if l.Provisioning.Grant != nil { + for _, query := range l.Provisioning.Grant.Queries { + err := validateVarsInQuery(s, query, l.Provisioning.Vars) + if err != nil { + return err + } + } + } + + if l.Provisioning.Revoke != nil { + for _, query := range l.Provisioning.Revoke.Queries { + err := validateVarsInQuery(s, query, l.Provisioning.Vars) + if err != nil { + return err + } + } + } + + return nil +} + +func (l *GrantsQuery) staticValidate(ctx context.Context, s *SQLSyncer) error { + return validateVarsInQuery(s, l.Query, l.Vars) +} + +func (l *AccountProvisioning) staticValidate(ctx context.Context, s *SQLSyncer) error { + if l.Credentials == nil { + return errors.New("no credentials defined") + } + + if l.Credentials.EncryptedPassword == nil && + l.Credentials.RandomPassword == nil && + l.Credentials.NoPassword == nil { + return errors.New("no credential method defined") + } + + if l.Credentials.RandomPassword != nil { + if l.Credentials.RandomPassword.MaxLength <= 0 { + return errors.New("random password max_length must be greater than zero") + } + + if l.Credentials.RandomPassword.MinLength <= 0 { + return errors.New("random password min_length must be greater than zero") + } + + if l.Credentials.RandomPassword.MinLength > l.Credentials.RandomPassword.MaxLength { + return errors.New("random password min_length cannot be greater than max_length") + } + } + + if l.Create == nil { + return errors.New("no create functions defined") + } + + for _, query := range l.Create.Queries { + err := validateVarsInQuery(s, query, l.Create.Vars) + if err != nil { + return err + } + } + + if l.Validate == nil { + return errors.New("no validate functions defined") + } + + err := validateVarsInQuery(s, l.Validate.Query, l.Validate.Vars) + if err != nil { + return err + } + + return nil +} + +func (l *CredentialRotation) staticValidate(ctx context.Context, s *SQLSyncer) error { + if l.Update != nil { + for _, query := range l.Update.Queries { + err := validateVarsInQuery(s, query, l.Update.Vars) + if err != nil { + return err + } + } + } + + return nil +} + +func (l *ActionConfig) staticValidate(ctx context.Context, s *SQLSyncer) error { + availableVars := make(map[string]string) + for k, v := range l.Vars { + availableVars[k] = v + } + + for k, config := range l.Arguments { + availableVars[k] = config.Type + } + + return validateVarsInQuery(s, l.Query, availableVars) +} diff --git a/pkg/bsql/validate_test.go b/pkg/bsql/validate_test.go new file mode 100644 index 00000000..7db3cf40 --- /dev/null +++ b/pkg/bsql/validate_test.go @@ -0,0 +1,51 @@ +package bsql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidate(t *testing.T) { + tcases := []struct { + name string + validator staticValidator + expectErr bool + }{ + { + name: "valid list query", + validator: &ListQuery{ + Query: "SELECT * FROM users WHERE id = ? LIMIT ? OFFSET ?", + Vars: map[string]string{ + "userid": "string", + }, + }, + expectErr: false, + }, + { + name: "invalid list query", + validator: &ListQuery{ + Query: "SELECT * FROM users WHERE id = ? LIMIT ? OFFSET ?", + Vars: map[string]string{ + "userid": "string", + }, + }, + expectErr: true, + }, + } + + for _, tc := range tcases { + t.Run(tc.name, func(t *testing.T) { + ctx := t.Context() + + syncer := &SQLSyncer{} + + err := tc.validator.staticValidate(ctx, syncer) + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index f078a09a..4a9b6131 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -76,7 +76,23 @@ func (c *Connector) Metadata(ctx context.Context) (*v2.ConnectorMetadata, error) // Validate is called to ensure that the connector is properly configured. It should exercise any API credentials // to be sure that they are valid. func (c *Connector) Validate(ctx context.Context) (annotations.Annotations, error) { - err := c.db.PingContext(ctx) + syncers, err := c.config.GetSQLSyncers(ctx, c.db, c.dbEngine, c.celEnv) + if err != nil { + return nil, err + } + + for _, syncer := range syncers { + if v, ok := syncer.(interface { + Validate(ctx context.Context) error + }); ok { + err := v.Validate(ctx) + if err != nil { + return nil, err + } + } + } + + err = c.db.PingContext(ctx) if err != nil { return nil, err }