Skip to content

Commit c6ef375

Browse files
author
James Cor
committed
fix ceil and floor typing
1 parent 811cc78 commit c6ef375

File tree

3 files changed

+57
-52
lines changed

3 files changed

+57
-52
lines changed

sql/expression/function/ceil_round_floor.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,12 @@ func (c *Ceil) Description() string {
5252
// Type implements the Expression interface.
5353
func (c *Ceil) Type() sql.Type {
5454
childType := c.Child.Type()
55-
if types.IsSigned(childType) {
56-
return types.Int64
57-
}
5855
if types.IsUnsigned(childType) {
5956
return types.Uint64
6057
}
58+
if types.IsNumber(childType) {
59+
return types.Int64
60+
}
6161
return types.Float64
6262
}
6363

@@ -99,15 +99,14 @@ func (c *Ceil) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
9999
// if it's number type and not float value, it does not need ceil-ing
100100
switch num := child.(type) {
101101
case float32:
102-
return math.Ceil(float64(num)), nil
102+
child = math.Ceil(float64(num))
103103
case float64:
104-
return math.Ceil(num), nil
104+
child = math.Ceil(num)
105105
case decimal.Decimal:
106-
return num.Ceil(), nil
107-
default:
108-
num, _, _ = c.Type().Convert(ctx, child)
109-
return num, nil
106+
child = num.Ceil()
110107
}
108+
child, _, _ = c.Type().Convert(ctx, child)
109+
return child, nil
111110
}
112111

113112
// Floor returns the biggest integer value not less than X.
@@ -136,12 +135,12 @@ func (f *Floor) Description() string {
136135
// Type implements the Expression interface.
137136
func (f *Floor) Type() sql.Type {
138137
childType := f.Child.Type()
139-
if types.IsSigned(childType) {
140-
return types.Int64
141-
}
142138
if types.IsUnsigned(childType) {
143139
return types.Uint64
144140
}
141+
if types.IsNumber(childType) {
142+
return types.Int64
143+
}
145144
return types.Float64
146145
}
147146

sql/expression/function/ceil_round_floor_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ func TestCeil(t *testing.T) {
3434
err *errors.Kind
3535
}{
3636
{"float64 is nil", types.Float64, sql.NewRow(nil), nil, nil},
37-
{"float64 is ok", types.Float64, sql.NewRow(5.8), 6.0, nil},
37+
{"float64 is ok", types.Float64, sql.NewRow(5.8), int64(6), nil},
3838
{"float32 is nil", types.Float32, sql.NewRow(nil), nil, nil},
39-
{"float32 is ok", types.Float32, sql.NewRow(float32(5.8)), 6.0, nil},
39+
{"float32 is ok", types.Float32, sql.NewRow(float32(5.8)), int64(6), nil},
4040
{"int32 is nil", types.Int32, sql.NewRow(nil), nil, nil},
4141
{"int32 is ok", types.Int32, sql.NewRow(int32(6)), int64(6), nil},
4242
{"int64 is nil", types.Int64, sql.NewRow(nil), nil, nil},
@@ -69,14 +69,14 @@ func TestCeil(t *testing.T) {
6969
require.Equal(tt.expected, result)
7070
}
7171

72-
// signed -> signed, unsigned -> unsigned, everything else -> double
72+
// unsigned -> unsigned, signed -> signed, everything else -> double
7373
resType := f.Type()
74-
if types.IsSigned(tt.rowType) {
75-
require.True(types.IsSigned(resType))
76-
} else if types.IsUnsigned(resType) {
77-
require.True(types.IsUnsigned(resType))
74+
if types.IsUnsigned(tt.rowType) {
75+
require.True(resType.Equals(types.Uint64))
76+
} else if types.IsNumber(tt.rowType) {
77+
require.True(resType.Equals(types.Int64))
7878
} else {
79-
require.True(types.IsFloat(resType))
79+
require.True(resType.Equals(types.Float64))
8080
}
8181
require.False(f.IsNullable())
8282
})
@@ -129,12 +129,12 @@ func TestFloor(t *testing.T) {
129129

130130
// signed -> signed, unsigned -> unsigned, everything else -> double
131131
resType := f.Type()
132-
if types.IsSigned(tt.rowType) {
133-
require.True(types.IsSigned(resType))
134-
} else if types.IsUnsigned(resType) {
135-
require.True(types.IsUnsigned(resType))
132+
if types.IsUnsigned(tt.rowType) {
133+
require.True(resType.Equals(types.Uint64))
134+
} else if types.IsNumber(tt.rowType) {
135+
require.True(resType.Equals(types.Int64))
136136
} else {
137-
require.True(types.IsFloat(resType))
137+
require.True(resType.Equals(types.Float64))
138138
}
139139
require.False(f.IsNullable())
140140
})

sql/types/number.go

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ var (
8989
numre = regexp.MustCompile(`^[ ]*[0-9]*\.?[0-9]+`)
9090
)
9191

92+
type Round bool
93+
94+
const (
95+
ShouldTruncate Round = false
96+
ShouldRound Round = true
97+
)
98+
9299
type NumberTypeImpl_ struct {
93100
baseType query.Type
94101
displayWidth int
@@ -112,7 +119,6 @@ func CreateNumberTypeWithDisplayWidth(baseType query.Type, displayWidth int) (sq
112119
switch baseType {
113120
case sqltypes.Int8, sqltypes.Uint8, sqltypes.Int16, sqltypes.Uint16, sqltypes.Int24, sqltypes.Uint24,
114121
sqltypes.Int32, sqltypes.Uint32, sqltypes.Int64, sqltypes.Uint64, sqltypes.Float32, sqltypes.Float64:
115-
116122
// displayWidth of 0 is valid for all types, displayWidth of 1 is only valid for Int8
117123
if displayWidth == 0 || (displayWidth == 1 && baseType == sqltypes.Int8) {
118124
return NumberTypeImpl_{
@@ -151,11 +157,11 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{}
151157

152158
switch t.baseType {
153159
case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64:
154-
ca, _, err := convertToUint64(t, a, false)
160+
ca, _, err := convertToUint64(t, a, ShouldTruncate)
155161
if err != nil {
156162
return 0, err
157163
}
158-
cb, _, err := convertToUint64(t, b, false)
164+
cb, _, err := convertToUint64(t, b, ShouldTruncate)
159165
if err != nil {
160166
return 0, err
161167
}
@@ -185,11 +191,11 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{}
185191
}
186192
return +1, nil
187193
default:
188-
ca, _, err := convertToInt64(t, a, false)
194+
ca, _, err := convertToInt64(t, a, ShouldTruncate)
189195
if err != nil {
190196
ca = 0
191197
}
192-
cb, _, err := convertToInt64(t, b, false)
198+
cb, _, err := convertToInt64(t, b, ShouldTruncate)
193199
if err != nil {
194200
cb = 0
195201
}
@@ -224,7 +230,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
224230

225231
switch t.baseType {
226232
case sqltypes.Int8:
227-
num, _, err := convertToInt64(t, v, false)
233+
num, _, err := convertToInt64(t, v, ShouldTruncate)
228234
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
229235
return int8(num), sql.OutOfRange, err
230236
}
@@ -236,7 +242,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
236242
}
237243
return int8(num), sql.InRange, err
238244
case sqltypes.Uint8:
239-
num, _, err := convertToInt64(t, v, false)
245+
num, _, err := convertToInt64(t, v, ShouldTruncate)
240246
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
241247
return uint8(num), sql.OutOfRange, err
242248
}
@@ -248,7 +254,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
248254
}
249255
return uint8(num), sql.InRange, err
250256
case sqltypes.Int16:
251-
num, _, err := convertToInt64(t, v, false)
257+
num, _, err := convertToInt64(t, v, ShouldTruncate)
252258
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
253259
return int16(num), sql.OutOfRange, err
254260
}
@@ -260,7 +266,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
260266
}
261267
return int16(num), sql.InRange, err
262268
case sqltypes.Uint16:
263-
num, _, err := convertToInt64(t, v, false)
269+
num, _, err := convertToInt64(t, v, ShouldTruncate)
264270
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
265271
return uint16(num), sql.OutOfRange, err
266272
}
@@ -272,7 +278,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
272278
}
273279
return uint16(num), sql.InRange, err
274280
case sqltypes.Int24:
275-
num, _, err := convertToInt64(t, v, false)
281+
num, _, err := convertToInt64(t, v, ShouldTruncate)
276282
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
277283
return int32(num), sql.OutOfRange, err
278284
}
@@ -284,7 +290,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
284290
}
285291
return int32(num), sql.InRange, err
286292
case sqltypes.Uint24:
287-
num, _, err := convertToInt64(t, v, false)
293+
num, _, err := convertToInt64(t, v, ShouldTruncate)
288294
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
289295
return uint32(num), sql.OutOfRange, err
290296
}
@@ -296,7 +302,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
296302
}
297303
return uint32(num), sql.InRange, err
298304
case sqltypes.Int32:
299-
num, _, err := convertToInt64(t, v, false)
305+
num, _, err := convertToInt64(t, v, ShouldTruncate)
300306
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
301307
return int32(num), sql.OutOfRange, err
302308
}
@@ -308,7 +314,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
308314
}
309315
return int32(num), sql.InRange, err
310316
case sqltypes.Uint32:
311-
num, _, err := convertToInt64(t, v, false)
317+
num, _, err := convertToInt64(t, v, ShouldTruncate)
312318
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
313319
return uint32(num), sql.OutOfRange, err
314320
}
@@ -320,9 +326,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{
320326
}
321327
return uint32(num), sql.InRange, err
322328
case sqltypes.Int64:
323-
return convertToInt64(t, v, false)
329+
return convertToInt64(t, v, ShouldTruncate)
324330
case sqltypes.Uint64:
325-
return convertToUint64(t, v, false)
331+
return convertToUint64(t, v, ShouldTruncate)
326332
case sqltypes.Float32:
327333
num, err := convertToFloat64(t, v)
328334
if err != nil && !sql.ErrTruncatedIncorrect.Is(err) {
@@ -350,7 +356,7 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
350356
}
351357
switch t.baseType {
352358
case sqltypes.Int8:
353-
num, _, err := convertToInt64(t, v, true)
359+
num, _, err := convertToInt64(t, v, ShouldRound)
354360
if err != nil {
355361
return int8(num), sql.OutOfRange, err
356362
}
@@ -362,7 +368,7 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
362368
}
363369
return int8(num), sql.InRange, nil
364370
case sqltypes.Uint8:
365-
num, _, err := convertToInt64(t, v, true)
371+
num, _, err := convertToInt64(t, v, ShouldRound)
366372
if err != nil {
367373
return uint8(num), sql.OutOfRange, err
368374
}
@@ -374,7 +380,7 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
374380
}
375381
return uint8(num), sql.InRange, nil
376382
case sqltypes.Int16:
377-
num, _, err := convertToInt64(t, v, true)
383+
num, _, err := convertToInt64(t, v, ShouldRound)
378384
if err != nil {
379385
return int16(num), sql.OutOfRange, err
380386
}
@@ -386,7 +392,7 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
386392
}
387393
return int16(num), sql.InRange, nil
388394
case sqltypes.Uint16:
389-
num, _, err := convertToInt64(t, v, true)
395+
num, _, err := convertToInt64(t, v, ShouldRound)
390396
if err != nil {
391397
return uint16(num), sql.OutOfRange, err
392398
}
@@ -398,7 +404,7 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
398404
}
399405
return uint16(num), sql.InRange, nil
400406
case sqltypes.Int24:
401-
num, _, err := convertToInt64(t, v, true)
407+
num, _, err := convertToInt64(t, v, ShouldRound)
402408
if err != nil {
403409
return int32(num), sql.OutOfRange, err
404410
}
@@ -410,7 +416,7 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
410416
}
411417
return int32(num), sql.InRange, nil
412418
case sqltypes.Uint24:
413-
num, _, err := convertToInt64(t, v, true)
419+
num, _, err := convertToInt64(t, v, ShouldRound)
414420
if err != nil {
415421
return uint32(num), sql.OutOfRange, err
416422
}
@@ -422,7 +428,7 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
422428
}
423429
return uint32(num), sql.InRange, nil
424430
case sqltypes.Int32:
425-
num, _, err := convertToInt64(t, v, true)
431+
num, _, err := convertToInt64(t, v, ShouldRound)
426432
if err != nil {
427433
return int32(num), sql.OutOfRange, err
428434
}
@@ -434,7 +440,7 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
434440
}
435441
return int32(num), sql.InRange, nil
436442
case sqltypes.Uint32:
437-
num, _, err := convertToInt64(t, v, true)
443+
num, _, err := convertToInt64(t, v, ShouldRound)
438444
if err != nil {
439445
return uint32(num), sql.OutOfRange, err
440446
}
@@ -446,9 +452,9 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any,
446452
}
447453
return uint32(num), sql.InRange, nil
448454
case sqltypes.Int64:
449-
return convertToInt64(t, v, true)
455+
return convertToInt64(t, v, ShouldRound)
450456
case sqltypes.Uint64:
451-
return convertToUint64(t, v, true)
457+
return convertToUint64(t, v, ShouldRound)
452458
default:
453459
return t.Convert(ctx, v)
454460
}
@@ -1045,7 +1051,7 @@ func (t NumberTypeImpl_) DisplayWidth() int {
10451051
return t.displayWidth
10461052
}
10471053

1048-
func convertToInt64(t NumberTypeImpl_, v interface{}, round bool) (int64, sql.ConvertInRange, error) {
1054+
func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertInRange, error) {
10491055
switch v := v.(type) {
10501056
case time.Time:
10511057
return v.UTC().Unix(), sql.InRange, nil
@@ -1231,7 +1237,7 @@ func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) {
12311237
}
12321238
}
12331239

1234-
func convertToUint64(t NumberTypeImpl_, v interface{}, round bool) (uint64, sql.ConvertInRange, error) {
1240+
func convertToUint64(t NumberTypeImpl_, v any, round Round) (uint64, sql.ConvertInRange, error) {
12351241
switch v := v.(type) {
12361242
case time.Time:
12371243
return uint64(v.UTC().Unix()), sql.InRange, nil

0 commit comments

Comments
 (0)