Skip to content

Commit f5dd5e9

Browse files
author
James Cor
committed
fix various math functions
1 parent 9a672b8 commit f5dd5e9

File tree

6 files changed

+117
-104
lines changed

6 files changed

+117
-104
lines changed

sql/expression/function/logarithm.go

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@ package function
1616

1717
import (
1818
"fmt"
19-
"math"
20-
"reflect"
21-
2219
"gopkg.in/src-d/go-errors.v1"
20+
"math"
2321

2422
"github.com/dolthub/go-mysql-server/sql"
2523
"github.com/dolthub/go-mysql-server/sql/expression"
2624
"github.com/dolthub/go-mysql-server/sql/types"
25+
"github.com/dolthub/vitess/go/mysql"
2726
)
2827

2928
// ErrInvalidArgumentForLogarithm is returned when an invalid argument value is passed to a
@@ -124,14 +123,13 @@ func (l *LogBase) Eval(
124123
if err != nil {
125124
return nil, err
126125
}
127-
128126
if v == nil {
129127
return nil, nil
130128
}
131129

132130
val, _, err := types.Float64.Convert(ctx, v)
133-
if err != nil {
134-
return nil, sql.ErrInvalidType.New(reflect.TypeOf(v))
131+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
132+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
135133
}
136134
return computeLog(ctx, val.(float64), l.base)
137135
}
@@ -206,28 +204,24 @@ func (l *Log) Eval(
206204
if err != nil {
207205
return nil, err
208206
}
209-
210207
if left == nil {
211208
return nil, nil
212209
}
213-
214210
lhs, _, err := types.Float64.Convert(ctx, left)
215-
if err != nil {
216-
return nil, sql.ErrInvalidType.New(reflect.TypeOf(left))
211+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
212+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
217213
}
218214

219215
right, err := l.RightChild.Eval(ctx, row)
220216
if err != nil {
221217
return nil, err
222218
}
223-
224219
if right == nil {
225220
return nil, nil
226221
}
227-
228222
rhs, _, err := types.Float64.Convert(ctx, right)
229-
if err != nil {
230-
return nil, sql.ErrInvalidType.New(reflect.TypeOf(right))
223+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
224+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
231225
}
232226

233227
// rhs becomes value, lhs becomes base
@@ -252,6 +246,6 @@ func computeLog(ctx *sql.Context, v float64, base float64) (interface{}, error)
252246
return math.Log(v), nil
253247
default:
254248
// LOG(BASE,V) is equivalent to LOG(V) / LOG(BASE).
255-
return float64(math.Log(v) / math.Log(base)), nil
249+
return math.Log(v) / math.Log(base), nil
256250
}
257251
}

sql/expression/function/logarithm_test.go

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@ func TestLn(t *testing.T) {
4040
{"Input value is null", types.Float64, sql.NewRow(nil), nil, nil},
4141
{"Input value is zero", types.Float64, sql.NewRow(0), nil, nil},
4242
{"Input value is negative", types.Float64, sql.NewRow(-1), nil, nil},
43-
{"Input value is valid string", types.Float64, sql.NewRow("2"), float64(0.6931471805599453), nil},
44-
{"Input value is invalid string", types.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType},
45-
{"Input value is valid float64", types.Float64, sql.NewRow(3), float64(1.0986122886681096), nil},
46-
{"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), float64(1.791759469228055), nil},
47-
{"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), float64(2.0794415416798357), nil},
48-
{"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), float64(2.302585092994046), nil},
43+
{"Input value is valid string", types.Float64, sql.NewRow("2"), 0.6931471805599453, nil},
44+
{"Input value is invalid string", types.Float64, sql.NewRow("aaa"), nil, nil},
45+
{"Input value is invalid string truncates", types.Float64, sql.NewRow("123.456"), 4.815884817283264, nil},
46+
{"Input value is valid float64", types.Float64, sql.NewRow(3), 1.0986122886681096, nil},
47+
{"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), 1.791759469228055, nil},
48+
{"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), 2.0794415416798357, nil},
49+
{"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), 2.302585092994046, nil},
4950
}
5051

5152
for _, tt := range testCases {
@@ -82,10 +83,11 @@ func TestLog2(t *testing.T) {
8283
{"Input value is negative", types.Float64, sql.NewRow(-1), nil, nil},
8384
{"Input value is valid string", types.Float64, sql.NewRow("2"), float64(1), nil},
8485
{"Input value is invalid string", types.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType},
85-
{"Input value is valid float64", types.Float64, sql.NewRow(3), float64(1.5849625007211563), nil},
86-
{"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), float64(2.584962500721156), nil},
87-
{"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), float64(3), nil},
88-
{"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), float64(3.321928094887362), nil},
86+
{"Input value is invalid string truncates", types.Float64, sql.NewRow("123.456"), 6.947853143387016, nil},
87+
{"Input value is valid float64", types.Float64, sql.NewRow(3), 1.5849625007211563, nil},
88+
{"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), 2.584962500721156, nil},
89+
{"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), 3.0, nil},
90+
{"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), 3.321928094887362, nil},
8991
}
9092

9193
for _, tt := range testCases {
@@ -122,10 +124,11 @@ func TestLog10(t *testing.T) {
122124
{"Input value is negative", types.Float64, sql.NewRow(-1), nil, nil},
123125
{"Input value is valid string", types.Float64, sql.NewRow("2"), float64(0.3010299956639812), nil},
124126
{"Input value is invalid string", types.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType},
125-
{"Input value is valid float64", types.Float64, sql.NewRow(3), float64(0.4771212547196624), nil},
126-
{"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), float64(0.7781512503836436), nil},
127-
{"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), float64(0.9030899869919435), nil},
128-
{"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), float64(1), nil},
127+
{"Input value is invalid string truncates", types.Float64, sql.NewRow("123.456"), 2.0915122016277716, nil},
128+
{"Input value is valid float64", types.Float64, sql.NewRow(3), 0.4771212547196624, nil},
129+
{"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), 0.7781512503836436, nil},
130+
{"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), 0.9030899869919435, nil},
131+
{"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), 1, nil},
129132
}
130133

131134
for _, tt := range testCases {
@@ -172,24 +175,26 @@ func TestLog(t *testing.T) {
172175
{"Input base is nil", []sql.Expression{expression.NewLiteral(nil, types.Float64), expression.NewLiteral(float64(10), types.Float64)}, nil, nil},
173176
{"Input base is zero", []sql.Expression{expression.NewLiteral(float64(0), types.Float64), expression.NewLiteral(float64(10), types.Float64)}, nil, nil},
174177
{"Input base is negative", []sql.Expression{expression.NewLiteral(float64(-5), types.Float64), expression.NewLiteral(float64(10), types.Float64)}, nil, nil},
175-
{"Input base is valid string", []sql.Expression{expression.NewLiteral("4", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, float64(1.6609640474436813), nil},
176-
{"Input base is invalid string", []sql.Expression{expression.NewLiteral("bbb", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, nil, sql.ErrInvalidType},
178+
{"Input base is valid string", []sql.Expression{expression.NewLiteral("4", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, 1.6609640474436813, nil},
179+
{"Input base is invalid string", []sql.Expression{expression.NewLiteral("bbb", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, nil, nil},
180+
{"Input base is invalid string truncates", []sql.Expression{expression.NewLiteral("1.23abc", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, 11.122838112203077, nil},
177181

178182
{"Input value is null", []sql.Expression{expression.NewLiteral(nil, types.Float64)}, nil, nil},
179183
{"Input value is zero", []sql.Expression{expression.NewLiteral(float64(0), types.Float64)}, nil, nil},
180184
{"Input value is negative", []sql.Expression{expression.NewLiteral(float64(-9), types.Float64)}, nil, nil},
181-
{"Input value is valid string", []sql.Expression{expression.NewLiteral("7", types.LongText)}, float64(1.9459101490553132), nil},
182-
{"Input value is invalid string", []sql.Expression{expression.NewLiteral("766j", types.LongText)}, nil, sql.ErrInvalidType},
183-
184-
{"Input base is valid float64", []sql.Expression{expression.NewLiteral(float64(5), types.Float64), expression.NewLiteral(float64(99), types.Float64)}, float64(2.855108491376949), nil},
185-
{"Input base is valid float32", []sql.Expression{expression.NewLiteral(float32(6), types.Float32), expression.NewLiteral(float64(80), types.Float64)}, float64(2.4456556306420936), nil},
186-
{"Input base is valid int64", []sql.Expression{expression.NewLiteral(int64(8), types.Int64), expression.NewLiteral(float64(64), types.Float64)}, float64(2), nil},
187-
{"Input base is valid int32", []sql.Expression{expression.NewLiteral(int32(10), types.Int32), expression.NewLiteral(float64(100), types.Float64)}, float64(2), nil},
188-
189-
{"Input value is valid float64", []sql.Expression{expression.NewLiteral(float64(5), types.Float64), expression.NewLiteral(float64(66), types.Float64)}, float64(2.6031788549643564), nil},
190-
{"Input value is valid float32", []sql.Expression{expression.NewLiteral(float32(3), types.Float32), expression.NewLiteral(float64(50), types.Float64)}, float64(3.560876795007312), nil},
191-
{"Input value is valid int64", []sql.Expression{expression.NewLiteral(int64(5), types.Int64), expression.NewLiteral(float64(77), types.Float64)}, float64(2.698958057527146), nil},
192-
{"Input value is valid int32", []sql.Expression{expression.NewLiteral(int32(4), types.Int32), expression.NewLiteral(float64(40), types.Float64)}, float64(2.6609640474436813), nil},
185+
{"Input value is valid string", []sql.Expression{expression.NewLiteral("7", types.LongText)}, 1.9459101490553132, nil},
186+
{"Input value is invalid string", []sql.Expression{expression.NewLiteral("bbb", types.LongText)}, nil, nil},
187+
{"Input value is invalid string truncates", []sql.Expression{expression.NewLiteral("766j", types.LongText)}, 6.641182169740591, nil},
188+
189+
{"Input base is valid float64", []sql.Expression{expression.NewLiteral(float64(5), types.Float64), expression.NewLiteral(float64(99), types.Float64)}, 2.855108491376949, nil},
190+
{"Input base is valid float32", []sql.Expression{expression.NewLiteral(float32(6), types.Float32), expression.NewLiteral(float64(80), types.Float64)}, 2.4456556306420936, nil},
191+
{"Input base is valid int64", []sql.Expression{expression.NewLiteral(int64(8), types.Int64), expression.NewLiteral(float64(64), types.Float64)}, 2.0, nil},
192+
{"Input base is valid int32", []sql.Expression{expression.NewLiteral(int32(10), types.Int32), expression.NewLiteral(float64(100), types.Float64)}, 2.0, nil},
193+
194+
{"Input value is valid float64", []sql.Expression{expression.NewLiteral(float64(5), types.Float64), expression.NewLiteral(float64(66), types.Float64)}, 2.6031788549643564, nil},
195+
{"Input value is valid float32", []sql.Expression{expression.NewLiteral(float32(3), types.Float32), expression.NewLiteral(float64(50), types.Float64)}, 3.560876795007312, nil},
196+
{"Input value is valid int64", []sql.Expression{expression.NewLiteral(int64(5), types.Int64), expression.NewLiteral(float64(77), types.Float64)}, 2.698958057527146, nil},
197+
{"Input value is valid int32", []sql.Expression{expression.NewLiteral(int32(4), types.Int32), expression.NewLiteral(float64(40), types.Float64)}, 2.6609640474436813, nil},
193198
}
194199

195200
for _, tt := range testCases {

sql/expression/function/math.go

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/dolthub/go-mysql-server/sql"
3030
"github.com/dolthub/go-mysql-server/sql/expression"
3131
"github.com/dolthub/go-mysql-server/sql/types"
32+
"github.com/dolthub/vitess/go/mysql"
3233
)
3334

3435
// Rand returns a random float 0 <= x < 1. If it has an argument, that argument will be used to seed the random number
@@ -129,15 +130,12 @@ func (r *Rand) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
129130
return nil, err
130131
}
131132

132-
var seed int64
133-
if types.IsNumber(r.Child.Type()) {
134-
e, _, err = types.Int64.Convert(ctx, e)
135-
if err == nil {
136-
seed = e.(int64)
137-
}
133+
e, _, err = types.Int64.Convert(ctx, e)
134+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
135+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
138136
}
139137

140-
return rand.New(rand.NewSource(seed)).Float64(), nil
138+
return rand.New(rand.NewSource(e.(int64))).Float64(), nil
141139
}
142140

143141
// Sin is the SIN function
@@ -175,8 +173,8 @@ func (s *Sin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
175173
}
176174

177175
n, _, err := types.Float64.Convert(ctx, val)
178-
if err != nil {
179-
return nil, err
176+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
177+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
180178
}
181179

182180
return math.Sin(n.(float64)), nil
@@ -224,8 +222,8 @@ func (s *Cos) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
224222
}
225223

226224
n, _, err := types.Float64.Convert(ctx, val)
227-
if err != nil {
228-
return nil, err
225+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
226+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
229227
}
230228

231229
return math.Cos(n.(float64)), nil
@@ -273,9 +271,10 @@ func (t *Tan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
273271
}
274272

275273
n, _, err := types.Float64.Convert(ctx, val)
276-
if err != nil {
277-
return nil, err
274+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
275+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
278276
}
277+
279278
res := math.Tan(n.(float64))
280279
if math.IsNaN(res) {
281280
return nil, nil
@@ -326,8 +325,8 @@ func (a *Asin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
326325
}
327326

328327
n, _, err := types.Float64.Convert(ctx, val)
329-
if err != nil {
330-
return nil, err
328+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
329+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
331330
}
332331

333332
res := math.Asin(n.(float64))
@@ -380,8 +379,8 @@ func (a *Acos) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
380379
}
381380

382381
n, _, err := types.Float64.Convert(ctx, val)
383-
if err != nil {
384-
return nil, err
382+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
383+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
385384
}
386385

387386
res := math.Acos(n.(float64))
@@ -489,13 +488,13 @@ func (a *Atan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
489488
}
490489

491490
nx, _, err := types.Float64.Convert(ctx, xx)
492-
if err != nil {
493-
return nil, err
491+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
492+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
494493
}
495494

496495
ny, _, err := types.Float64.Convert(ctx, yy)
497-
if err != nil {
498-
return nil, err
496+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
497+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
499498
}
500499

501500
return math.Atan2(ny.(float64), nx.(float64)), nil
@@ -548,8 +547,8 @@ func (c *Cot) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
548547
}
549548

550549
n, _, err := types.Float64.Convert(ctx, val)
551-
if err != nil {
552-
return nil, err
550+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
551+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
553552
}
554553

555554
tan := math.Tan(n.(float64))
@@ -612,8 +611,8 @@ func (d *Degrees) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
612611
}
613612

614613
n, _, err := types.Float64.Convert(ctx, val)
615-
if err != nil {
616-
return nil, err
614+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
615+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
617616
}
618617

619618
return (n.(float64) * 180.0) / math.Pi, nil
@@ -661,8 +660,8 @@ func (r *Radians) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
661660
}
662661

663662
n, _, err := types.Float64.Convert(ctx, val)
664-
if err != nil {
665-
return nil, err
663+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
664+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
666665
}
667666

668667
return (n.(float64) * math.Pi) / 180.0, nil
@@ -975,14 +974,11 @@ func (e *Exp) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
975974
}
976975

977976
v, _, err := types.Float64.Convert(ctx, val)
978-
if err != nil {
979-
// TODO: truncate
980-
ctx.Warn(1292, "Truncated incorrect DOUBLE value: '%v'", val)
981-
v = 0.0
977+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
978+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
982979
}
983980

984-
vv := v.(float64)
985-
res := math.Exp(vv)
981+
res := math.Exp(v.(float64))
986982

987983
if math.IsNaN(res) || math.IsInf(res, 0) {
988984
return nil, nil

sql/expression/function/math_test.go

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,22 @@ func TestRandWithSeed(t *testing.T) {
8484
f642 = f.(float64)
8585

8686
assert.Equal(t, f64, f642)
87+
88+
r, _ = NewRand(expression.NewLiteral("10 not a number", types.LongText))
89+
assert.Equal(t, `rand('10 not a number')`, r.String())
90+
91+
f, err = r.Eval(nil, nil)
92+
require.NoError(t, err)
93+
f64 = f.(float64)
94+
95+
assert.GreaterOrEqual(t, f64, float64(0))
96+
assert.Less(t, f64, float64(1))
97+
98+
f, err = r.Eval(nil, nil)
99+
require.NoError(t, err)
100+
f642 = f.(float64)
101+
102+
assert.Equal(t, f64, f642)
87103
}
88104

89105
func TestRadians(t *testing.T) {
@@ -94,6 +110,7 @@ func TestRadians(t *testing.T) {
94110
tf.AddSucceeding(math.Pi, int16(180))
95111
tf.AddSucceeding(math.Pi/2.0, (90))
96112
tf.AddSucceeding(2*math.Pi, 360.0)
113+
tf.AddSucceeding(math.Pi, "180.0abc")
97114
tf.Test(t, nil, nil)
98115
}
99116

@@ -107,6 +124,7 @@ func TestDegrees(t *testing.T) {
107124
{"decimal 2pi", decimal.NewFromFloat(2 * math.Pi), 360.0},
108125
{"float64 pi/2", math.Pi / 2.0, 90.0},
109126
{"float32 3*pi/2", float32(3.0 * math.Pi / 2.0), 270.0},
127+
{"string truncates", "3.1415926536ABC", 180.0},
110128
}
111129

112130
f := sql.Function1{Name: "degrees", Fn: NewDegrees}
@@ -395,13 +413,16 @@ func TestExp(t *testing.T) {
395413
exp: math.Exp(10),
396414
},
397415
{
398-
// we don't do truncation yet
399-
// https://github.com/dolthub/dolt/issues/7302
400-
name: "scientific string is truncated",
416+
name: "scientific string is evaluated",
401417
arg: expression.NewLiteral("1e1", types.Text),
402-
exp: "",
418+
exp: math.Exp(10),
419+
err: false,
420+
},
421+
{
422+
name: "scientific string is truncated",
423+
arg: expression.NewLiteral("10abc", types.Text),
424+
exp: math.Exp(10),
403425
err: false,
404-
skip: true,
405426
},
406427
}
407428

0 commit comments

Comments
 (0)