Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion database/db/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ func (r *Query) Chunk(size uint64, callback func(rows []db.Row) error) error {
}

func (r *Query) Count() (int64, error) {
r.conditions.Selects = []string{"COUNT(*)"}
if err := buildSelectForCount(r); err != nil {
return 0, err
}

sql, args, err := r.buildSelect()
if err != nil {
Expand Down Expand Up @@ -1304,3 +1306,32 @@ func (r *Query) trace(builder db.CommonBuilder, sql string, args []any, now *car
r.logger.Trace(r.ctx, now, builder.Explain(sql, args...), rowsAffected, err)
}
}

func buildSelectForCount(query *Query) error {
// If selectColumns only contains a raw select with spaces (rename), gorm will fail, but this case will appear when calling Paginate, so user COUNT(*) here.
// If there are multiple selectColumns, gorm will transform them into *, so no need to handle that case.
// For example: Select("name as n").Count() will fail, but Select("name", "age as a").Count() will be treated as Select("*").Count()
if len(query.conditions.Selects) > 1 {
query.conditions.Selects = []string{"COUNT(*)"}
} else if len(query.conditions.Selects) == 1 {
if str.Of(query.conditions.Selects[0]).Trim().Contains(" ") {
query.conditions.Selects = []string{"COUNT(*)"}
} else {
if query.conditions.Distinct != nil && *query.conditions.Distinct {
query.conditions.Selects = []string{fmt.Sprintf("COUNT(DISTINCT %s)", query.conditions.Selects[0])}
} else {
query.conditions.Selects = []string{fmt.Sprintf("COUNT(%s)", query.conditions.Selects[0])}
}
}
} else {
if query.conditions.Distinct != nil && *query.conditions.Distinct {
return errors.DatabaseCountDistinctWithoutColumns
} else {
query.conditions.Selects = []string{"COUNT(*)"}
}
}

query.conditions.Distinct = nil

return nil
}
314 changes: 276 additions & 38 deletions database/db/query_test.go

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion database/db/to_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ func NewToSql(query *Query, raw bool) *ToSql {
}

func (r *ToSql) Count() string {
r.query.conditions.Selects = []string{"COUNT(*)"}
if err := buildSelectForCount(r.query); err != nil {
return r.generate(r.query.readBuilder, "", nil, err)
}

sql, args, err := r.query.buildSelect()

return r.generate(r.query.readBuilder, sql, args, err)
Expand Down
19 changes: 11 additions & 8 deletions database/gorm/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/goravel/framework/support/collect"
"github.com/goravel/framework/support/database"
"github.com/goravel/framework/support/deep"
"github.com/goravel/framework/support/str"
)

const Associations = clause.Associations
Expand Down Expand Up @@ -120,7 +121,16 @@ func (r *Query) Commit() error {
}

func (r *Query) Count() (int64, error) {
query := r.resetSelect().addGlobalScopes().buildConditions()
conditions := r.conditions

// If selectColumns only contains a raw select with spaces (rename), gorm will fail, but this case will appear when calling Paginate, so user COUNT(*) here.
// If there are multiple selectColumns, gorm will transform them into *, so no need to handle that case.
// For example: Select("name as n").Count() will fail, but Select("name", "age as a").Count() will be treated as Select("*").Count()
if len(conditions.selectColumns) == 1 && str.Of(conditions.selectColumns[0]).Trim().Contains(" ") {
conditions.selectColumns = nil
}

query := r.setConditions(conditions).addGlobalScopes().buildConditions()

var count int64

Expand Down Expand Up @@ -1680,13 +1690,6 @@ func (r *Query) refreshConnection() (*Query, error) {
return query, nil
}

func (r *Query) resetSelect() *Query {
conditions := r.conditions
conditions.selectColumns = nil

return r.setConditions(conditions)
}

func (r *Query) restored(dest any) error {
return r.event(contractsorm.EventRestored, r.conditions.model, dest)
}
Expand Down
13 changes: 12 additions & 1 deletion database/gorm/to_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/goravel/framework/contracts/log"
"github.com/goravel/framework/support/database"
"github.com/goravel/framework/support/str"
)

type ToSql struct {
Expand All @@ -22,7 +23,17 @@ func NewToSql(query *Query, log log.Log, raw bool) *ToSql {
}

func (r *ToSql) Count() string {
query := r.query.addGlobalScopes().buildConditions()
conditions := r.query.conditions

// If selectColumns only contains a raw select with spaces, gorm will fail, hence ignore it here.
// If there are multiple selectColumns, gorm will transform them into *, so no need to handle that case.
// For example: Select("name as n").Count() will fail, but Select("name", "age as a").Count() will be treated as Select("*").Count()
if len(conditions.selectColumns) == 1 && str.Of(conditions.selectColumns[0]).Trim().Contains(" ") {
conditions.selectColumns = nil
}

query := r.query.setConditions(conditions).addGlobalScopes().buildConditions()

var count int64

return r.sql(query.instance.Session(&gorm.Session{DryRun: true}).Count(&count))
Expand Down
1 change: 1 addition & 0 deletions errors/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ var (
CryptMissingValueKey = New("decrypt payload error: missing value key")

DatabaseConfigNotFound = New("not found database configuration")
DatabaseCountDistinctWithoutColumns = New("cannot use Count with Distinct without specifying columns")
DatabaseTableIsRequired = New("table is required")
DatabaseForceIsRequiredInProduction = New("application in production use --force to run this command")
DatabaseSeederNotFound = New("not found %s seeder")
Expand Down
2 changes: 1 addition & 1 deletion support/constant.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package support

const (
Version = "v1.16.6"
Version = "v1.16.7"

RuntimeArtisan = "artisan"
RuntimeTest = "test"
Expand Down
80 changes: 75 additions & 5 deletions tests/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,34 @@ func (s *DBTestSuite) TestCount() {
{Name: "count_product1"},
{Name: "count_product2"},
})

count, err := query.DB().Table("products").Count()
s.NoError(err)
s.Equal(int64(2), count)

count, err = query.DB().Table("products").Where("name", "count_product1").Count()
s.NoError(err)
s.Equal(int64(1), count)

count, err = query.DB().Table("products").Select("name", "weight").Where("name", "count_product1").Count()
s.NoError(err)
s.Equal(int64(1), count)

count, err = query.DB().Table("products").Select("name as n", "weight").Where("name", "count_product1").Count()
s.NoError(err)
s.Equal(int64(1), count)

count, err = query.DB().Table("products").Select("name as n").Where("name", "count_product1").Count()
s.NoError(err)
s.Equal(int64(1), count)

count, err = query.DB().Table("products").Select("name n").Where("name", "count_product1").Count()
s.NoError(err)
s.Equal(int64(1), count)

count, err = query.DB().Table("products").Select("name").Where("name", "count_product1").Count()
s.NoError(err)
s.Equal(int64(1), count)
})
}
}
Expand Down Expand Up @@ -308,21 +333,53 @@ func (s *DBTestSuite) TestDistinct() {
for driver, query := range s.queries {
s.Run(driver, func() {
query.DB().Table("products").Insert([]Product{
{Name: "distinct_product", Weight: convert.Pointer(1)},
{Name: "distinct_product"},
{Name: "distinct_product"},
})

var products []Product
err := query.DB().Table("products").Distinct().Select("name").Get(&products)

err := query.DB().Table("products").Distinct().OrderBy("id").Get(&products)
s.NoError(err)
s.Equal(3, len(products))
s.Equal("distinct_product", products[0].Name)
s.Equal(1, *products[0].Weight)
s.Equal("distinct_product", products[1].Name)
s.Nil(products[1].Weight)
s.Equal("distinct_product", products[2].Name)
s.Nil(products[2].Weight)

err = query.DB().Table("products").Distinct().Select("name").Get(&products)
s.NoError(err)
s.Equal(1, len(products))
s.Equal("distinct_product", products[0].Name)

var products1 []Product
err = query.DB().Table("products").Distinct("name").Get(&products1)
err = query.DB().Table("products").Distinct("name").Get(&products)
s.NoError(err)
s.Equal(1, len(products))
s.Equal("distinct_product", products[0].Name)

err = query.DB().Table("products").Distinct("name", "weight").Get(&products)
s.NoError(err)
s.Equal(2, len(products))

count, err := query.DB().Table("products").Distinct().Count()
s.Error(err)
s.Equal(int64(0), count)

count, err = query.DB().Table("products").Distinct("name").Count()
s.NoError(err)
s.Equal(int64(1), count)

count, err = query.DB().Table("products").Distinct("name").Select("name").Count()
s.NoError(err)
s.Equal(int64(1), count)

// Gorm cannot support multiple distinct fields count directly, the sql will be COUNT(*), keep consistent here.
count, err = query.DB().Table("products").Distinct("name", "weight").Count()
s.NoError(err)
s.Equal(int64(3), count)
})
}
}
Expand Down Expand Up @@ -923,7 +980,6 @@ func (s *DBTestSuite) TestPaginate() {
s.Equal("paginate_product1", products[0].Name)
s.Equal("paginate_product2", products[1].Name)

products = []Product{}
err = query.DB().Table("products").WhereLike("name", "paginate_product%").Paginate(2, 2, &products, &total)
s.NoError(err)
s.Equal(2, len(products))
Expand All @@ -932,13 +988,27 @@ func (s *DBTestSuite) TestPaginate() {
s.Equal("paginate_product4", products[1].Name)

// Fix: https://github.com/goravel/goravel/issues/842
products = []Product{}
err = query.DB().Table("products").Select("name as name", "weight").WhereLike("name", "paginate_product%").Paginate(2, 2, &products, &total)
s.NoError(err)
s.Equal(2, len(products))
s.Equal(int64(5), total)
s.Equal("paginate_product3", products[0].Name)
s.Equal("paginate_product4", products[1].Name)

err = query.DB().Table("products").Select("name as name").WhereLike("name", "paginate_product%").Paginate(2, 2, &products, &total)
s.NoError(err)
s.Equal(2, len(products))
s.Equal(int64(5), total)
s.Equal("paginate_product3", products[0].Name)
s.Equal("paginate_product4", products[1].Name)

err = query.DB().Table("products").Select("name name").WhereLike("name", "paginate_product%").Paginate(2, 2, &products, &total)
s.NoError(err)
s.Equal(2, len(products))
s.Equal(int64(5), total)
s.Equal("paginate_product3", products[0].Name)
s.Equal("paginate_product4", products[1].Name)

})
}
}
Expand Down
76 changes: 68 additions & 8 deletions tests/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,33 @@ func (s *QueryTestSuite) TestCount() {
s.Nil(query.Query().Create(&user1))
s.True(user1.ID > 0)

count, err := query.Query().Model(&User{}).Where("name = ?", "count_user").Count()
count, err := query.Query().Model(&User{}).Where("name", "count_user").Count()
s.Nil(err)
s.True(count > 0)
s.Equal(int64(2), count)

count, err = query.Query().Table("users").Where("name = ?", "count_user").Count()
count, err = query.Query().Table("users").Where("avatar", "count_avatar1").Count()
s.Nil(err)
s.Equal(int64(1), count)

count, err = query.Query().Model(&User{}).Select("name", "avatar").Where("avatar", "count_avatar1").Count()
s.Nil(err)
s.True(count > 0)
s.Equal(int64(1), count)

count, err = query.Query().Model(&User{}).Select("name as n", "avatar").Where("avatar", "count_avatar1").Count()
s.Nil(err)
s.Equal(int64(1), count)

count, err = query.Query().Model(&User{}).Select("name as n").Where("avatar", "count_avatar1").Count()
s.Nil(err)
s.Equal(int64(1), count)

count, err = query.Query().Model(&User{}).Select("name n").Where("avatar", "count_avatar1").Count()
s.Nil(err)
s.Equal(int64(1), count)

count, err = query.Query().Model(&User{}).Select("name").Where("avatar", "count_avatar1").Count()
s.Nil(err)
s.Equal(int64(1), count)
})
}
}
Expand Down Expand Up @@ -1000,13 +1020,41 @@ func (s *QueryTestSuite) TestDistinct() {
s.Nil(query.Query().Create(&user1))
s.True(user1.ID > 0)

user2 := User{Name: "distinct_user", Avatar: "distinct_avatar"}
s.Nil(query.Query().Create(&user2))
s.True(user2.ID > 0)

var users []User

s.Nil(query.Query().Distinct().Find(&users))
s.Equal(3, len(users))

s.Nil(query.Query().Distinct("name").Find(&users, []uint{user.ID, user1.ID}))
s.Equal(1, len(users))

var users1 []User
s.Nil(query.Query().Distinct().Select("name").Find(&users1, []uint{user.ID, user1.ID}))
s.Equal(1, len(users1))
s.Nil(query.Query().Distinct("name", "avatar").Find(&users, []uint{user.ID, user1.ID}))
s.Equal(2, len(users))

s.Nil(query.Query().Distinct().Select("name").Find(&users, []uint{user.ID, user1.ID}))
s.Equal(1, len(users))

// Select should be set when calling Count with Distinct
count, err := query.Query().Model(&User{}).Distinct().Count()
s.Error(err)
s.Equal(int64(0), count)

count, err = query.Query().Model(&User{}).Distinct("name").Count()
s.Nil(err)
s.Equal(int64(1), count)

count, err = query.Query().Model(&User{}).Distinct("name").Select("name").Count()
s.Nil(err)
s.Equal(int64(1), count)

// Gorm cannot support multiple distinct fields count directly, the sql will be COUNT(*).
count, err = query.Query().Model(&User{}).Distinct("name", "avatar").Count()
s.Nil(err)
s.Equal(int64(3), count)
})
}
}
Expand Down Expand Up @@ -3018,9 +3066,21 @@ func (s *QueryTestSuite) TestPaginate() {
// Fix: https://github.com/goravel/goravel/issues/842
var users4 []User
var total4 int64
s.Nil(query.Query().Model(&User{}).Select("name as name").Where("name", "paginate_user").Paginate(1, 3, &users4, &total4))
s.Nil(query.Query().Model(&User{}).Select("name as name", "avatar").Where("name", "paginate_user").Paginate(1, 3, &users4, &total4))
s.Equal(3, len(users4))
s.Equal(int64(4), total4)

var users5 []User
var total5 int64
s.Nil(query.Query().Model(&User{}).Select("name as name").Where("name", "paginate_user").Paginate(1, 3, &users5, &total5))
s.Equal(3, len(users5))
s.Equal(int64(4), total5)

var users6 []User
var total6 int64
s.Nil(query.Query().Model(&User{}).Select("name name").Where("name", "paginate_user").Paginate(1, 3, &users6, &total6))
s.Equal(3, len(users6))
s.Equal(int64(4), total6)
})
}
}
Expand Down
Loading