Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion do.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (d *DO) WithContext(ctx context.Context) Dao { return d.getInstance(d.db.Wi

// Clauses specify Clauses
func (d *DO) Clauses(conds ...clause.Expression) Dao {
if err := checkConds(conds); err != nil {
if err := checkCondsWithChecker(conds, d.ClauseChecker); err != nil {
newDB := d.db.Session(new(gorm.Session))
_ = newDB.AddError(err)
return d.getInstance(newDB)
Expand Down
41 changes: 41 additions & 0 deletions do_clause_checker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package gen

import (
"testing"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils/tests"
"gorm.io/plugin/dbresolver"
)

func TestDOClausesWithClauseChecker(t *testing.T) {
base, err := gorm.Open(tests.DummyDialector{}, &gorm.Config{})
if err != nil {
t.Fatalf("open db: %v", err)
}

dry1 := base.Session(&gorm.Session{DryRun: true, NewDB: true})
var d1 DO
d1.UseDB(dry1)
dao1 := d1.Clauses(dbresolver.Use("db_1")).(*DO)
if dao1.db.Error == nil {
t.Fatalf("expected error, got nil")
}

dry2 := base.Session(&gorm.Session{DryRun: true, NewDB: true})
var d2 DO
d2.UseDB(dry2, WithClauseChecker(func(clause.Expression) error { return ErrClauseNotHandled }))
dao2 := d2.Clauses(dbresolver.Use("db_1")).(*DO)
if dao2.db.Error == nil {
t.Fatalf("expected error, got nil")
}

dry3 := base.Session(&gorm.Session{DryRun: true, NewDB: true})
var d3 DO
d3.UseDB(dry3, WithClauseChecker(func(clause.Expression) error { return nil }))
dao3 := d3.Clauses(dbresolver.Use("db_1")).(*DO)
if dao3.db.Error != nil {
t.Fatalf("unexpected error: %v", dao3.db.Error)
}
}
26 changes: 26 additions & 0 deletions do_options.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
package gen

import (
"errors"

"gorm.io/gorm/clause"
)

// DOOption gorm option interface
type DOOption interface {
Apply(*DOConfig) error
AfterInitialize(*DO) error
}

type ClauseChecker func(clause.Expression) error

var ErrClauseNotHandled = errors.New("clause not handled")

type DOConfig struct {
ClauseChecker ClauseChecker
}

// Apply update config to new config
Expand All @@ -21,3 +32,18 @@ func (c *DOConfig) Apply(config *DOConfig) error {
func (c *DOConfig) AfterInitialize(db *DO) error {
return nil
}

type clauseCheckerOption struct {
checker ClauseChecker
}

func (o clauseCheckerOption) Apply(cfg *DOConfig) error {
cfg.ClauseChecker = o.checker
return nil
}

func (clauseCheckerOption) AfterInitialize(*DO) error { return nil }

func WithClauseChecker(checker ClauseChecker) DOOption {
return clauseCheckerOption{checker: checker}
}
16 changes: 16 additions & 0 deletions sec_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ func checkConds(conds []clause.Expression) error {
return nil
}

func checkCondsWithChecker(conds []clause.Expression, checker ClauseChecker) error {
for _, cond := range conds {
if checker != nil {
if err := checker(cond); err == nil {
continue
} else if err != ErrClauseNotHandled {
return err
}
}
if err := CheckClause(cond); err != nil {
return err
}
}
return nil
}

var banClauses = map[string]bool{
// "INSERT": true,
"VALUES": true,
Expand Down
Loading