diff --git a/do.go b/do.go index af9dde27..087a36ff 100644 --- a/do.go +++ b/do.go @@ -172,7 +172,7 @@ func (d *DO) WithContext(ctx context.Context) Dao { return d.getInstance(d.db.Wi // Clauses specify Clauses func (d *DO) Clauses(conds ...clause.Expression) Dao { - if err := checkConds(conds); err != nil { + if err := checkCondsWithChecker(conds, d.ClauseChecker); err != nil { newDB := d.db.Session(new(gorm.Session)) _ = newDB.AddError(err) return d.getInstance(newDB) diff --git a/do_clause_checker_test.go b/do_clause_checker_test.go new file mode 100644 index 00000000..4af1dc20 --- /dev/null +++ b/do_clause_checker_test.go @@ -0,0 +1,41 @@ +package gen + +import ( + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/utils/tests" + "gorm.io/plugin/dbresolver" +) + +func TestDOClausesWithClauseChecker(t *testing.T) { + base, err := gorm.Open(tests.DummyDialector{}, &gorm.Config{}) + if err != nil { + t.Fatalf("open db: %v", err) + } + + dry1 := base.Session(&gorm.Session{DryRun: true, NewDB: true}) + var d1 DO + d1.UseDB(dry1) + dao1 := d1.Clauses(dbresolver.Use("db_1")).(*DO) + if dao1.db.Error == nil { + t.Fatalf("expected error, got nil") + } + + dry2 := base.Session(&gorm.Session{DryRun: true, NewDB: true}) + var d2 DO + d2.UseDB(dry2, WithClauseChecker(func(clause.Expression) error { return ErrClauseNotHandled })) + dao2 := d2.Clauses(dbresolver.Use("db_1")).(*DO) + if dao2.db.Error == nil { + t.Fatalf("expected error, got nil") + } + + dry3 := base.Session(&gorm.Session{DryRun: true, NewDB: true}) + var d3 DO + d3.UseDB(dry3, WithClauseChecker(func(clause.Expression) error { return nil })) + dao3 := d3.Clauses(dbresolver.Use("db_1")).(*DO) + if dao3.db.Error != nil { + t.Fatalf("unexpected error: %v", dao3.db.Error) + } +} diff --git a/do_options.go b/do_options.go index 98bcaa61..8d667313 100644 --- a/do_options.go +++ b/do_options.go @@ -1,12 +1,23 @@ package gen +import ( + "errors" + + "gorm.io/gorm/clause" +) + // DOOption gorm option interface type DOOption interface { Apply(*DOConfig) error AfterInitialize(*DO) error } +type ClauseChecker func(clause.Expression) error + +var ErrClauseNotHandled = errors.New("clause not handled") + type DOConfig struct { + ClauseChecker ClauseChecker } // Apply update config to new config @@ -21,3 +32,18 @@ func (c *DOConfig) Apply(config *DOConfig) error { func (c *DOConfig) AfterInitialize(db *DO) error { return nil } + +type clauseCheckerOption struct { + checker ClauseChecker +} + +func (o clauseCheckerOption) Apply(cfg *DOConfig) error { + cfg.ClauseChecker = o.checker + return nil +} + +func (clauseCheckerOption) AfterInitialize(*DO) error { return nil } + +func WithClauseChecker(checker ClauseChecker) DOOption { + return clauseCheckerOption{checker: checker} +} diff --git a/sec_check.go b/sec_check.go index 10fe693c..6cbb6da5 100644 --- a/sec_check.go +++ b/sec_check.go @@ -19,6 +19,22 @@ func checkConds(conds []clause.Expression) error { return nil } +func checkCondsWithChecker(conds []clause.Expression, checker ClauseChecker) error { + for _, cond := range conds { + if checker != nil { + if err := checker(cond); err == nil { + continue + } else if err != ErrClauseNotHandled { + return err + } + } + if err := CheckClause(cond); err != nil { + return err + } + } + return nil +} + var banClauses = map[string]bool{ // "INSERT": true, "VALUES": true,