Skip to content

Commit 5e0f5f9

Browse files
authored
fix div type (#1848)
1 parent f3cc2fe commit 5e0f5f9

File tree

3 files changed

+24
-35
lines changed

3 files changed

+24
-35
lines changed

enginetest/queries/queries.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3293,9 +3293,9 @@ Select * from (
32933293
{
32943294
Query: "SELECT unix_timestamp(timestamp_col) div 60 * 60 as timestamp_col, avg(i) from datetime_table group by 1 order by unix_timestamp(timestamp_col) div 60 * 60",
32953295
Expected: []sql.Row{
3296-
{"1577966400", 1.0},
3297-
{"1578225600", 2.0},
3298-
{"1578398400", 3.0}},
3296+
{int64(1577966400), 1.0},
3297+
{int64(1578225600), 2.0},
3298+
{int64(1578398400), 3.0}},
32993299
SkipPrepared: true,
33003300
},
33013301
{

enginetest/queries/script_queries.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,10 +1591,10 @@ var ScriptTests = []ScriptTest{
15911591
Query: `SELECT UNIX_TIMESTAMP(time) DIV 60 * 60 AS "time", avg(value) AS "value"
15921592
FROM test GROUP BY 1 ORDER BY UNIX_TIMESTAMP(test.time) DIV 60 * 60`,
15931593
Expected: []sql.Row{
1594-
{"1625133600", 4.0},
1595-
{"1625220000", 3.0},
1596-
{"1625306400", 2.0},
1597-
{"1625392800", 1.0},
1594+
{int64(1625133600), 4.0},
1595+
{int64(1625220000), 3.0},
1596+
{int64(1625306400), 2.0},
1597+
{int64(1625392800), 1.0},
15981598
},
15991599
},
16001600
},

sql/expression/div.go

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -646,33 +646,14 @@ func (i *IntDiv) IsNullable() bool {
646646

647647
// Type returns the greatest type for given operation.
648648
func (i *IntDiv) Type() sql.Type {
649-
//TODO: what if both BindVars? should be constant folded
650-
rTyp := i.Right.Type()
651-
if types.IsDeferredType(rTyp) {
652-
return rTyp
653-
}
654649
lTyp := i.Left.Type()
655-
if types.IsDeferredType(lTyp) {
656-
return lTyp
657-
}
658-
659-
if types.IsTime(lTyp) && types.IsTime(rTyp) {
660-
return types.Int64
661-
}
662-
663-
if types.IsText(lTyp) || types.IsText(rTyp) {
664-
return types.Float64
665-
}
650+
rTyp := i.Right.Type()
666651

667-
if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) {
652+
if types.IsUnsigned(lTyp) || types.IsUnsigned(rTyp) {
668653
return types.Uint64
669-
} else if types.IsSigned(lTyp) && types.IsSigned(rTyp) {
670-
return types.Int64
671654
}
672655

673-
// using max precision which is 65.
674-
defType := types.MustCreateDecimalType(65, 0)
675-
return defType
656+
return types.Int64
676657
}
677658

678659
// CollationCoercibility implements the interface sql.CollationCoercible.
@@ -729,19 +710,27 @@ func (i *IntDiv) evalLeftRight(ctx *sql.Context, row sql.Row) (interface{}, inte
729710
// The decimal types of left and right value does NOT need to be the same. Both the types
730711
// should be preserved.
731712
func (i *IntDiv) convertLeftRight(ctx *sql.Context, left interface{}, right interface{}) (interface{}, interface{}) {
732-
typ := i.Type()
733-
lIsTimeType := types.IsTime(i.Left.Type())
734-
rIsTimeType := types.IsTime(i.Right.Type())
713+
var typ sql.Type
714+
lTyp, rTyp := i.Left.Type(), i.Right.Type()
715+
lIsTimeType := types.IsTime(lTyp)
716+
rIsTimeType := types.IsTime(rTyp)
735717

736-
if types.IsInteger(typ) || types.IsFloat(typ) {
737-
left = convertValueToType(ctx, typ, left, lIsTimeType)
718+
if types.IsText(lTyp) || types.IsText(rTyp) {
719+
typ = types.Float64
720+
} else if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) {
721+
typ = types.Uint64
722+
} else if (lIsTimeType && rIsTimeType) || (types.IsSigned(lTyp) && types.IsSigned(rTyp)) {
723+
typ = types.Int64
738724
} else {
739-
left = convertToDecimalValue(left, lIsTimeType)
725+
// using max precision which is 65.
726+
typ = types.MustCreateDecimalType(65, 0)
740727
}
741728

742729
if types.IsInteger(typ) || types.IsFloat(typ) {
730+
left = convertValueToType(ctx, typ, left, lIsTimeType)
743731
right = convertValueToType(ctx, typ, right, rIsTimeType)
744732
} else {
733+
left = convertToDecimalValue(left, lIsTimeType)
745734
right = convertToDecimalValue(right, rIsTimeType)
746735
}
747736

0 commit comments

Comments
 (0)