Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,23 @@ func BuildQuerySQL(db *gorm.DB) {
}
}

if db.Statement.Schema != nil && len(db.Statement.WhereHasConditions) > 0 {
for _, whereHasCondition := range db.Statement.WhereHasConditions {
cl, err := newWhereHas(db, whereHasCondition.IsDoesntHave, whereHasCondition.Relation, whereHasCondition.Conds, db.Statement.Schema)
if err != nil {
_ = db.AddError(err)

continue
}

if cl == nil {
continue
}

db.Statement.AddClause(cl)
}
}

if len(db.Statement.Selects) > 0 {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
for idx, name := range db.Statement.Selects {
Expand Down
209 changes: 209 additions & 0 deletions callbacks/where_has.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
package callbacks

import (
"fmt"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

func whereHasDb(db *gorm.DB) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})

return tx
}

var relationHandlers = map[schema.RelationshipType]func(*gorm.DB, *gorm.DB, *schema.Relationship, []interface{}) (*gorm.DB, error){
schema.Many2Many: existsMany2many,
schema.BelongsTo: existsBelongsTo,
schema.HasMany: existsHasMany,
schema.HasOne: existsHasOne,
}

func newWhereHas(db *gorm.DB, isDoesntHave bool, relationName string, conds []interface{}, s *schema.Schema) (*clause.Where, error) {
var err error

rel, ok := s.Relationships.Relations[relationName]
if !ok {
return nil, fmt.Errorf("relation %s not found", relationName)
}

tx := whereHasDb(db)

reflectResults := rel.FieldSchema.MakeSlice().Elem()

firstPrimaryField := ""
otherPrimaryFields := make([]interface{}, 0)

for i, name := range rel.FieldSchema.PrimaryFieldDBNames {
if i == 0 {
firstPrimaryField = name
} else {
otherPrimaryFields = append(otherPrimaryFields, name)
}
}

tx = tx.Model(reflectResults.Addr().Interface()).Select(firstPrimaryField, otherPrimaryFields...)
if err = tx.Statement.Parse(tx.Statement.Model); err != nil {
return nil, err
}

handler, ok := relationHandlers[rel.Type]
if !ok {
return nil, fmt.Errorf("unsupported relation type: %v", rel.Type)
}

tx, err = handler(db, tx, rel, conds)
if err != nil {
return nil, err
}

cond := "EXISTS(?)"
if isDoesntHave {
cond = "NOT " + cond
}

cl := clause.Where{
Exprs: []clause.Expression{
clause.Expr{
SQL: cond,
Vars: []interface{}{tx},
WithoutParentheses: false,
},
},
}

return &cl, nil
}

func existsHasOne(mainQuery *gorm.DB, existsQuery *gorm.DB, rel *schema.Relationship, conds []interface{}) (*gorm.DB, error) {
return existsHasMany(mainQuery, existsQuery, rel, conds)
}

func existsHasMany(mainQuery *gorm.DB, existsQuery *gorm.DB, rel *schema.Relationship, conds []interface{}) (*gorm.DB, error) {
if len(rel.References) < 1 {
return nil, fmt.Errorf("relation %s has no references", rel.Name)
}

for _, reference := range rel.References {
if reference.PrimaryKey != nil {
existsQuery.Statement.AddClause(clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Table: existsQuery.Statement.Table, Name: reference.ForeignKey.DBName},
Value: clause.Column{Table: mainQuery.Statement.Table, Name: reference.PrimaryKey.DBName},
},
},
})
} else {
existsQuery.Statement.AddClause(clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Table: existsQuery.Statement.Table, Name: reference.ForeignKey.DBName},
Value: reference.PrimaryValue,
},
},
})
}
}

existsQuery = applyConds(existsQuery, conds)

return existsQuery, nil
}

func existsBelongsTo(mainQuery *gorm.DB, existsQuery *gorm.DB, rel *schema.Relationship, conds []interface{}) (*gorm.DB, error) {
if len(rel.References) < 1 {
return nil, fmt.Errorf("relation %s has no references", rel.Name)
}

for _, reference := range rel.References {
existsQuery.Statement.AddClause(clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Table: existsQuery.Statement.Table, Name: reference.PrimaryKey.DBName},
Value: clause.Column{Table: mainQuery.Statement.Table, Name: reference.ForeignKey.DBName},
},
},
})
}

existsQuery = applyConds(existsQuery, conds)

return existsQuery, nil
}

func existsMany2many(mainQuery *gorm.DB, existsQuery *gorm.DB, rel *schema.Relationship, conds []interface{}) (*gorm.DB, error) {
if rel.JoinTable != nil {
var parentTableField *schema.Reference = nil
var primaryTableField *schema.Reference = nil

for _, reference := range rel.References {
if !reference.OwnPrimaryKey {
parentTableField = reference
} else {
primaryTableField = reference
}
}

if parentTableField == nil {
return nil, fmt.Errorf("relation %s has no parent table field", rel.Name)
}

if primaryTableField == nil {
return nil, fmt.Errorf("relation %s has no primary table field", rel.Name)
}

fromClause := clause.From{
Tables: nil,
Joins: []clause.Join{
{
Type: clause.InnerJoin,
Table: clause.Table{Name: rel.JoinTable.Table},
ON: clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: parentTableField.ForeignKey.DBName},
Value: clause.Column{Table: existsQuery.Statement.Table, Name: parentTableField.PrimaryKey.DBName},
},
},
},
},
},
}

existsQuery.Statement.AddClause(fromClause)

existsQuery.Statement.AddClause(clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: primaryTableField.ForeignKey.DBName},
Value: clause.Column{Table: mainQuery.Statement.Table, Name: primaryTableField.PrimaryKey.DBName},
},
},
})

existsQuery = applyConds(existsQuery, conds)
}

return existsQuery, nil
}

func applyConds(existsQuery *gorm.DB, conds []interface{}) *gorm.DB {
inlineConds := make([]interface{}, 0)

for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
existsQuery = fc(existsQuery)
} else {
inlineConds = append(inlineConds, cond)
}
}

if len(inlineConds) > 0 {
existsQuery = existsQuery.Where(inlineConds[0], inlineConds[1:]...)
}
Comment on lines +204 to +206
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BestPractice]

Array bounds safety issue: applyConds assumes that if inlineConds has length > 0, it can safely access inlineConds[0] and pass the rest as variadic args. However, there's a subtle issue - if inlineConds has only one element, inlineConds[1:]... will be an empty slice, which is fine. But for better readability and explicit handling:

Suggested Change
Suggested change
if len(inlineConds) > 0 {
existsQuery = existsQuery.Where(inlineConds[0], inlineConds[1:]...)
}
if len(inlineConds) > 0 {
if len(inlineConds) == 1 {
existsQuery = existsQuery.Where(inlineConds[0])
} else {
existsQuery = existsQuery.Where(inlineConds[0], inlineConds[1:]...)
}
}

Committable suggestion

Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Context for Agents
[**BestPractice**]

Array bounds safety issue: `applyConds` assumes that if `inlineConds` has length > 0, it can safely access `inlineConds[0]` and pass the rest as variadic args. However, there's a subtle issue - if `inlineConds` has only one element, `inlineConds[1:]...` will be an empty slice, which is fine. But for better readability and explicit handling:

<details>
<summary>Suggested Change</summary>

```suggestion
	if len(inlineConds) > 0 {
		if len(inlineConds) == 1 {
			existsQuery = existsQuery.Where(inlineConds[0])
		} else {
			existsQuery = existsQuery.Where(inlineConds[0], inlineConds[1:]...)
		}
	}
```

⚡ **Committable suggestion**

Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

</details>

File: callbacks/where_has.go
Line: 192

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the code is inspired by the code in the file callbacks/preload.go (line 288) and this code is correct, does not cause errors, and I think readability is better than the proposed version.


return existsQuery
}
32 changes: 32 additions & 0 deletions chainable_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,38 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
return
}

func (db *DB) WhereHas(relation string, args ...interface{}) (tx *DB) {
tx = db.getInstance()

if tx.Statement.WhereHasConditions == nil {
tx.Statement.WhereHasConditions = make([]whereHasCondition, 0)
}

tx.Statement.WhereHasConditions = append(tx.Statement.WhereHasConditions, whereHasCondition{
IsDoesntHave: false,
Relation: relation,
Conds: args,
})

return
}

func (db *DB) WhereDoesntHave(relation string, args ...interface{}) (tx *DB) {
tx = db.getInstance()

if tx.Statement.WhereHasConditions == nil {
tx.Statement.WhereHasConditions = make([]whereHasCondition, 0)
}

tx.Statement.WhereHasConditions = append(tx.Statement.WhereHasConditions, whereHasCondition{
IsDoesntHave: true,
Relation: relation,
Conds: args,
})

return
}

// Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
Expand Down
12 changes: 12 additions & 0 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Statement struct {
ColumnMapping map[string]string // map columns
Joins []join
Preloads map[string][]interface{}
WhereHasConditions []whereHasCondition
Settings sync.Map
ConnPool ConnPool
Schema *schema.Schema
Expand All @@ -50,6 +51,12 @@ type Statement struct {
Result *result
}

type whereHasCondition struct {
IsDoesntHave bool
Relation string
Conds []interface{}
}

type join struct {
Name string
Alias string
Expand Down Expand Up @@ -355,6 +362,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
conds = append(conds, cs.Expression)
}
}

stmt.WhereHasConditions = append(stmt.WhereHasConditions, v.Statement.WhereHasConditions...)
case map[interface{}]interface{}:
for i, j := range v {
conds = append(conds, clause.Eq{Column: i, Value: j})
Expand Down Expand Up @@ -540,6 +549,7 @@ func (stmt *Statement) clone() *Statement {
Omits: stmt.Omits,
ColumnMapping: stmt.ColumnMapping,
Preloads: map[string][]interface{}{},
WhereHasConditions: []whereHasCondition{},
ConnPool: stmt.ConnPool,
Schema: stmt.Schema,
Context: stmt.Context,
Expand All @@ -562,6 +572,8 @@ func (stmt *Statement) clone() *Statement {
newStmt.Preloads[k] = p
}

newStmt.WhereHasConditions = append(newStmt.WhereHasConditions, stmt.WhereHasConditions...)

if len(stmt.Joins) > 0 {
newStmt.Joins = make([]join, len(stmt.Joins))
copy(newStmt.Joins, stmt.Joins)
Expand Down
Loading
Loading