diff --git a/callbacks/query.go b/callbacks/query.go index 548bf7092..0276b03a9 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -57,6 +57,23 @@ func BuildQuerySQL(db *gorm.DB) { } } + if db.Statement.Schema != nil && len(db.Statement.WhereHasConditions) > 0 { + for _, whereHasCondition := range db.Statement.WhereHasConditions { + cl, err := newWhereHas(db, whereHasCondition.IsDoesntHave, whereHasCondition.Relation, whereHasCondition.Conds, db.Statement.Schema) + if err != nil { + _ = db.AddError(err) + + continue + } + + if cl == nil { + continue + } + + db.Statement.AddClause(cl) + } + } + if len(db.Statement.Selects) > 0 { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) for idx, name := range db.Statement.Selects { diff --git a/callbacks/where_has.go b/callbacks/where_has.go new file mode 100644 index 000000000..3ef067eca --- /dev/null +++ b/callbacks/where_has.go @@ -0,0 +1,209 @@ +package callbacks + +import ( + "fmt" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func whereHasDb(db *gorm.DB) *gorm.DB { + tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + + return tx +} + +var relationHandlers = map[schema.RelationshipType]func(*gorm.DB, *gorm.DB, *schema.Relationship, []interface{}) (*gorm.DB, error){ + schema.Many2Many: existsMany2many, + schema.BelongsTo: existsBelongsTo, + schema.HasMany: existsHasMany, + schema.HasOne: existsHasOne, +} + +func newWhereHas(db *gorm.DB, isDoesntHave bool, relationName string, conds []interface{}, s *schema.Schema) (*clause.Where, error) { + var err error + + rel, ok := s.Relationships.Relations[relationName] + if !ok { + return nil, fmt.Errorf("relation %s not found", relationName) + } + + tx := whereHasDb(db) + + reflectResults := rel.FieldSchema.MakeSlice().Elem() + + firstPrimaryField := "" + otherPrimaryFields := make([]interface{}, 0) + + for i, name := range rel.FieldSchema.PrimaryFieldDBNames { + if i == 0 { + firstPrimaryField = name + } else { + otherPrimaryFields = append(otherPrimaryFields, name) + } + } + + tx = tx.Model(reflectResults.Addr().Interface()).Select(firstPrimaryField, otherPrimaryFields...) + if err = tx.Statement.Parse(tx.Statement.Model); err != nil { + return nil, err + } + + handler, ok := relationHandlers[rel.Type] + if !ok { + return nil, fmt.Errorf("unsupported relation type: %v", rel.Type) + } + + tx, err = handler(db, tx, rel, conds) + if err != nil { + return nil, err + } + + cond := "EXISTS(?)" + if isDoesntHave { + cond = "NOT " + cond + } + + cl := clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{ + SQL: cond, + Vars: []interface{}{tx}, + WithoutParentheses: false, + }, + }, + } + + return &cl, nil +} + +func existsHasOne(mainQuery *gorm.DB, existsQuery *gorm.DB, rel *schema.Relationship, conds []interface{}) (*gorm.DB, error) { + return existsHasMany(mainQuery, existsQuery, rel, conds) +} + +func existsHasMany(mainQuery *gorm.DB, existsQuery *gorm.DB, rel *schema.Relationship, conds []interface{}) (*gorm.DB, error) { + if len(rel.References) < 1 { + return nil, fmt.Errorf("relation %s has no references", rel.Name) + } + + for _, reference := range rel.References { + if reference.PrimaryKey != nil { + existsQuery.Statement.AddClause(clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{Table: existsQuery.Statement.Table, Name: reference.ForeignKey.DBName}, + Value: clause.Column{Table: mainQuery.Statement.Table, Name: reference.PrimaryKey.DBName}, + }, + }, + }) + } else { + existsQuery.Statement.AddClause(clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{Table: existsQuery.Statement.Table, Name: reference.ForeignKey.DBName}, + Value: reference.PrimaryValue, + }, + }, + }) + } + } + + existsQuery = applyConds(existsQuery, conds) + + return existsQuery, nil +} + +func existsBelongsTo(mainQuery *gorm.DB, existsQuery *gorm.DB, rel *schema.Relationship, conds []interface{}) (*gorm.DB, error) { + if len(rel.References) < 1 { + return nil, fmt.Errorf("relation %s has no references", rel.Name) + } + + for _, reference := range rel.References { + existsQuery.Statement.AddClause(clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{Table: existsQuery.Statement.Table, Name: reference.PrimaryKey.DBName}, + Value: clause.Column{Table: mainQuery.Statement.Table, Name: reference.ForeignKey.DBName}, + }, + }, + }) + } + + existsQuery = applyConds(existsQuery, conds) + + return existsQuery, nil +} + +func existsMany2many(mainQuery *gorm.DB, existsQuery *gorm.DB, rel *schema.Relationship, conds []interface{}) (*gorm.DB, error) { + if rel.JoinTable != nil { + var parentTableField *schema.Reference = nil + var primaryTableField *schema.Reference = nil + + for _, reference := range rel.References { + if !reference.OwnPrimaryKey { + parentTableField = reference + } else { + primaryTableField = reference + } + } + + if parentTableField == nil { + return nil, fmt.Errorf("relation %s has no parent table field", rel.Name) + } + + if primaryTableField == nil { + return nil, fmt.Errorf("relation %s has no primary table field", rel.Name) + } + + fromClause := clause.From{ + Tables: nil, + Joins: []clause.Join{ + { + Type: clause.InnerJoin, + Table: clause.Table{Name: rel.JoinTable.Table}, + ON: clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: parentTableField.ForeignKey.DBName}, + Value: clause.Column{Table: existsQuery.Statement.Table, Name: parentTableField.PrimaryKey.DBName}, + }, + }, + }, + }, + }, + } + + existsQuery.Statement.AddClause(fromClause) + + existsQuery.Statement.AddClause(clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: primaryTableField.ForeignKey.DBName}, + Value: clause.Column{Table: mainQuery.Statement.Table, Name: primaryTableField.PrimaryKey.DBName}, + }, + }, + }) + + existsQuery = applyConds(existsQuery, conds) + } + + return existsQuery, nil +} + +func applyConds(existsQuery *gorm.DB, conds []interface{}) *gorm.DB { + inlineConds := make([]interface{}, 0) + + for _, cond := range conds { + if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { + existsQuery = fc(existsQuery) + } else { + inlineConds = append(inlineConds, cond) + } + } + + if len(inlineConds) > 0 { + existsQuery = existsQuery.Where(inlineConds[0], inlineConds[1:]...) + } + + return existsQuery +} diff --git a/chainable_api.go b/chainable_api.go index 8f6113cc1..214c828bd 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -212,6 +212,38 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { return } +func (db *DB) WhereHas(relation string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + + if tx.Statement.WhereHasConditions == nil { + tx.Statement.WhereHasConditions = make([]whereHasCondition, 0) + } + + tx.Statement.WhereHasConditions = append(tx.Statement.WhereHasConditions, whereHasCondition{ + IsDoesntHave: false, + Relation: relation, + Conds: args, + }) + + return +} + +func (db *DB) WhereDoesntHave(relation string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + + if tx.Statement.WhereHasConditions == nil { + tx.Statement.WhereHasConditions = make([]whereHasCondition, 0) + } + + tx.Statement.WhereHasConditions = append(tx.Statement.WhereHasConditions, whereHasCondition{ + IsDoesntHave: true, + Relation: relation, + Conds: args, + }) + + return +} + // Not add NOT conditions // // Not works similarly to where, and has the same syntax. diff --git a/statement.go b/statement.go index 736087d7a..764134d18 100644 --- a/statement.go +++ b/statement.go @@ -35,6 +35,7 @@ type Statement struct { ColumnMapping map[string]string // map columns Joins []join Preloads map[string][]interface{} + WhereHasConditions []whereHasCondition Settings sync.Map ConnPool ConnPool Schema *schema.Schema @@ -50,6 +51,12 @@ type Statement struct { Result *result } +type whereHasCondition struct { + IsDoesntHave bool + Relation string + Conds []interface{} +} + type join struct { Name string Alias string @@ -355,6 +362,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds = append(conds, cs.Expression) } } + + stmt.WhereHasConditions = append(stmt.WhereHasConditions, v.Statement.WhereHasConditions...) case map[interface{}]interface{}: for i, j := range v { conds = append(conds, clause.Eq{Column: i, Value: j}) @@ -540,6 +549,7 @@ func (stmt *Statement) clone() *Statement { Omits: stmt.Omits, ColumnMapping: stmt.ColumnMapping, Preloads: map[string][]interface{}{}, + WhereHasConditions: []whereHasCondition{}, ConnPool: stmt.ConnPool, Schema: stmt.Schema, Context: stmt.Context, @@ -562,6 +572,8 @@ func (stmt *Statement) clone() *Statement { newStmt.Preloads[k] = p } + newStmt.WhereHasConditions = append(newStmt.WhereHasConditions, stmt.WhereHasConditions...) + if len(stmt.Joins) > 0 { newStmt.Joins = make([]join, len(stmt.Joins)) copy(newStmt.Joins, stmt.Joins) diff --git a/tests/query_where_has_test.go b/tests/query_where_has_test.go new file mode 100644 index 000000000..77b5e7fed --- /dev/null +++ b/tests/query_where_has_test.go @@ -0,0 +1,158 @@ +package tests_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestQueryWhereHas(t *testing.T) { + DB.Create(&Language{Code: "wh_ro", Name: "Romanian"}) + + DB.Create(&[]User{ + { + Name: "u1_has_pets_with_toy", + Pets: []*Pet{ + { + Name: "u1_pet_1_with_toy", + Toy: Toy{ + Name: "u1_p1_toy_1", + }, + }, + + { + Name: "u1_pet_2_with_toy", + Toy: Toy{ + Name: "u1_p1_toy_2", + }, + }, + }, + Toys: []Toy{ + { + Name: "u1_toy_1", + }, + }, + Languages: []Language{ + { + Code: "wh_en", + Name: "English", + }, + }, + }, + + { + Name: "u2_has_pets_with_without_toy", + Pets: []*Pet{ + { + Name: "u2_pet_1_with_toy", + Toy: Toy{ + Name: "u2_p1_toy_1", + }, + }, + + { + Name: "u2_pet_2_without_toy", + }, + }, + Toys: []Toy{ + { + Name: "u2_toy_1", + }, + }, + Languages: []Language{ + { + Code: "wh_en", + Name: "English", + }, + { + Code: "wh_it", + Name: "Italian", + }, + }, + }, + { + Name: "u3_has_pets_without_toy", + Pets: []*Pet{ + { + Name: "u3_pet_1_without_toy", + }, + + { + Name: "u3_pet_2_without_toy", + }, + }, + Toys: []Toy{ + { + Name: "u3_toy_1", + }, + }, + }, + }) + + t.Run("OneToOne", func(t *testing.T) { + var err error + + var pet Pet + petLookUpName := "u1_pet_1_with_toy" + + pet = Pet{} + DB.Where("name = ?", petLookUpName).WhereHas("Toy").First(&pet) + assert.Equal(t, petLookUpName, pet.Name) + + pet = Pet{} + DB.Where("name = ?", petLookUpName).WhereHas("Toy", DB.Where("name = ?", "u1_p1_toy_1")).First(&pet) + assert.Equal(t, petLookUpName, pet.Name) + + pet = Pet{} + err = DB.Where("name = ?", petLookUpName).WhereDoesntHave("Toy").First(&pet).Error + assert.Equal(t, gorm.ErrRecordNotFound, err) + }) + + t.Run("HasMany", func(t *testing.T) { + var err error + var user User + + user = User{} + DB.Where("name = ?", "u1_has_pets_with_toy").WhereHas("Pets").First(&user) + assert.Equal(t, "u1_has_pets_with_toy", user.Name) + + user = User{} + DB.Where("name = ?", "u1_has_pets_with_toy").WhereHas("Pets", DB.Where("name = ?", "u1_pet_1_with_toy")).First(&user) + assert.Equal(t, "u1_has_pets_with_toy", user.Name) + + user = User{} + err = DB.Where("name = ?", "u1_has_pets_with_toy").WhereDoesntHave("Pets").First(&user).Error + assert.Equal(t, gorm.ErrRecordNotFound, err) + }) + + t.Run("ManyToMany", func(t *testing.T) { + var err error + var user User + + user = User{} + DB.Where("name = ?", "u1_has_pets_with_toy").WhereHas("Languages").First(&user) + assert.Equal(t, "u1_has_pets_with_toy", user.Name) + + user = User{} + err = DB.Where("name = ?", "u1_has_pets_with_toy").WhereDoesntHave("Languages").First(&user).Error + assert.Equal(t, gorm.ErrRecordNotFound, err) + + user = User{} + err = DB.Where("name = ?", "u3_has_pets_without_toy").WhereHas("Languages").First(&user).Error + assert.Equal(t, gorm.ErrRecordNotFound, err) + + var users []User + DB.WhereHas("Languages", DB.Where("code = ?", "wh_it")).Find(&users) + assert.Equal(t, 1, len(users)) + }) + + t.Run("Nested", func(t *testing.T) { + var user User + + user = User{} + DB.WhereHas("Pets", DB.WhereHas("Toy", DB.Where("name = ?", "u1_p1_toy_1"))).First(&user) + assert.Equal(t, user.Name, "u1_has_pets_with_toy") + }) +}