diff --git a/finisher_api.go b/finisher_api.go index e9e35f1bf..4b94f185f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -482,6 +482,8 @@ func (db *DB) Count(count *int64) (tx *DB) { } tx.Statement.AddClause(clause.Select{Expression: expr}) + } else if strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: tx.Statement.Selects[0]}}) } if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { diff --git a/tests/count_test.go b/tests/count_test.go index bdeba8f04..cc56650ad 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -95,6 +95,16 @@ func TestCount(t *testing.T) { t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) } + result = dryDB.Model(&User{}).Distinct("name").Joins("Team").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.* JOIN .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Select("COUNT(DISTINCT(`name`))").Joins("Team").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.* JOIN .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } + var count4 int64 if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count4)