Skip to content

Commit e30d14d

Browse files
committed
add truncate() func
1 parent 2be75e4 commit e30d14d

File tree

5 files changed

+526
-10
lines changed

5 files changed

+526
-10
lines changed

enginetest/memory_engine_test.go

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,19 +203,68 @@ func TestSingleScript(t *testing.T) {
203203
t.Skip()
204204
var scripts = []queries.ScriptTest{
205205
{
206-
Name: "AS OF propagates to nested CALLs",
206+
// https://github.com/dolthub/dolt/issues/9916
207+
Name: "TRUNCATE(X,D) function behavior",
207208
SetUpScript: []string{},
208209
Assertions: []queries.ScriptTestAssertion{
209210
{
210-
Query: "create procedure create_proc() create table t (i int primary key, j int);",
211+
Query: "SELECT TRUNCATE(1.223,1)",
211212
Expected: []sql.Row{
212-
{types.NewOkResult(0)},
213+
{"1.2"},
213214
},
214215
},
215216
{
216-
Query: "call create_proc()",
217+
Query: "SELECT TRUNCATE(1.999,1)",
217218
Expected: []sql.Row{
218-
{types.NewOkResult(0)},
219+
{"1.9"},
220+
},
221+
},
222+
{
223+
Query: "SELECT TRUNCATE(1.999,0)",
224+
Expected: []sql.Row{
225+
{"1"},
226+
},
227+
},
228+
{
229+
Query: "SELECT TRUNCATE(-1.999,1)",
230+
Expected: []sql.Row{
231+
{"-1.9"},
232+
},
233+
},
234+
{
235+
Query: "SELECT TRUNCATE(122,-2)",
236+
Expected: []sql.Row{
237+
{100},
238+
},
239+
},
240+
{
241+
Query: "SELECT TRUNCATE(10.28*100,0)",
242+
Expected: []sql.Row{
243+
{"1028"},
244+
},
245+
},
246+
{
247+
Query: "SELECT TRUNCATE(NULL,1)",
248+
Expected: []sql.Row{
249+
{nil},
250+
},
251+
},
252+
{
253+
Query: "SELECT TRUNCATE(1.223,NULL)",
254+
Expected: []sql.Row{
255+
{nil},
256+
},
257+
},
258+
{
259+
Query: "SELECT TRUNCATE(0.5,0)",
260+
Expected: []sql.Row{
261+
{"0"},
262+
},
263+
},
264+
{
265+
Query: "SELECT TRUNCATE(-0.5,0)",
266+
Expected: []sql.Row{
267+
{"0"},
219268
},
220269
},
221270
},

sql/expression/function/ceil_round_floor.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,26 @@ import (
2626
"github.com/dolthub/go-mysql-server/sql/types"
2727
)
2828

29+
// numericRetType returns the appropriate return type for numeric functions
30+
// like ROUND() and TRUNCATE() according to MySQL specification:
31+
// Integer types return BIGINT
32+
// Floating-point types or non-numeric types return DOUBLE
33+
// DECIMAL values return DECIMAL
34+
func numericRetType(inputType sql.Type) sql.Type {
35+
if types.IsSigned(inputType) || types.IsUnsigned(inputType) {
36+
return types.Int64 // BIGINT
37+
} else if types.IsFloat(inputType) {
38+
return types.Float64 // DOUBLE
39+
} else if types.IsDecimal(inputType) {
40+
return inputType // DECIMAL (same type)
41+
} else if types.IsTextBlob(inputType) {
42+
return types.Float64 // DOUBLE for non-numeric types
43+
}
44+
45+
// Default fallback
46+
return types.Float64
47+
}
48+
2949
// Ceil returns the smallest integer value not less than X.
3050
type Ceil struct {
3151
expression.UnaryExpression
@@ -321,11 +341,7 @@ func (r *Round) Resolved() bool {
321341

322342
// Type implements the Expression interface.
323343
func (r *Round) Type() sql.Type {
324-
leftChildType := r.LeftChild.Type()
325-
if types.IsNumber(leftChildType) {
326-
return leftChildType
327-
}
328-
return types.Int32
344+
return numericRetType(r.LeftChild.Type())
329345
}
330346

331347
// CollationCoercibility implements the interface sql.CollationCoercible.

sql/expression/function/registry.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ var BuiltIns = []sql.Function{
333333
sql.FunctionN{Name: "distance", Fn: vector.NewGenericDistance},
334334
sql.Function1{Name: "string_to_vector", Fn: vector.NewStringToVector},
335335
sql.Function1{Name: "to_vector", Fn: vector.NewStringToVector},
336+
sql.FunctionN{Name: "truncate", Fn: NewTruncate},
336337
sql.Function1{Name: "vec_fromtext", Fn: vector.NewStringToVector},
337338
sql.Function1{Name: "vector_to_string", Fn: vector.NewVectorToString},
338339
sql.Function1{Name: "from_vector", Fn: vector.NewVectorToString},
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/vitess/go/mysql"
21+
"github.com/shopspring/decimal"
22+
23+
"github.com/dolthub/go-mysql-server/sql"
24+
"github.com/dolthub/go-mysql-server/sql/expression"
25+
"github.com/dolthub/go-mysql-server/sql/types"
26+
)
27+
28+
// Truncate truncates a number to a specified number of decimal places.
29+
// If D is 0, the result has no decimal point or fractional part.
30+
// D can be negative to cause D digits left of the decimal point of the value X to become zero.
31+
// If X or D is NULL, the function returns NULL.
32+
// All numbers are rounded toward zero.
33+
type Truncate struct {
34+
expression.BinaryExpressionStub
35+
}
36+
37+
var _ sql.FunctionExpression = (*Truncate)(nil)
38+
var _ sql.CollationCoercible = (*Truncate)(nil)
39+
40+
// NewTruncate returns a new Truncate expression.
41+
func NewTruncate(args ...sql.Expression) (sql.Expression, error) {
42+
argLen := len(args)
43+
if argLen != 2 {
44+
return nil, sql.ErrInvalidArgumentNumber.New("TRUNCATE", "2", argLen)
45+
}
46+
47+
return &Truncate{expression.BinaryExpressionStub{LeftChild: args[0], RightChild: args[1]}}, nil
48+
}
49+
50+
// FunctionName implements sql.FunctionExpression
51+
func (t *Truncate) FunctionName() string {
52+
return "truncate"
53+
}
54+
55+
// Description implements sql.FunctionExpression
56+
func (t *Truncate) Description() string {
57+
return "truncates the number to decimals decimal places."
58+
}
59+
60+
// Children implements the Expression interface.
61+
func (t *Truncate) Children() []sql.Expression {
62+
return t.BinaryExpressionStub.Children()
63+
}
64+
65+
// Eval implements the Expression interface.
66+
func (t *Truncate) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
67+
val, err := t.LeftChild.Eval(ctx, row)
68+
if err != nil {
69+
return nil, err
70+
}
71+
if val == nil {
72+
return nil, nil
73+
}
74+
75+
val, _, err = types.InternalDecimalType.Convert(ctx, val)
76+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
77+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
78+
}
79+
80+
prec, err := t.RightChild.Eval(ctx, row)
81+
if err != nil {
82+
return nil, err
83+
}
84+
if prec == nil {
85+
return nil, nil
86+
}
87+
prec, _, err = types.Int32.Convert(ctx, prec)
88+
if err != nil {
89+
if !sql.ErrTruncatedIncorrect.Is(err) {
90+
return nil, err
91+
}
92+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
93+
}
94+
precision := prec.(int32)
95+
96+
// MySQL cuts off at 30 for larger values
97+
// TODO: these limits are fine only because we can't handle decimals larger than this
98+
if precision > types.DecimalTypeMaxPrecision {
99+
precision = types.DecimalTypeMaxPrecision
100+
}
101+
if precision < -types.DecimalTypeMaxScale {
102+
precision = -types.DecimalTypeMaxScale
103+
}
104+
105+
var res interface{}
106+
var tmp decimal.Decimal
107+
108+
if precision < 0 {
109+
// For negative precision, we need to truncate digits to the left of decimal point
110+
// This is different from the decimal library's Truncate method
111+
// We need to divide by 10^|precision|, truncate, then multiply back
112+
multiplier := decimal.NewFromInt(1)
113+
for i := int32(0); i < -precision; i++ {
114+
multiplier = multiplier.Mul(decimal.NewFromInt(10))
115+
}
116+
tmp = val.(decimal.Decimal).Div(multiplier).Truncate(0).Mul(multiplier)
117+
} else {
118+
// For positive precision, use the standard Truncate method
119+
tmp = val.(decimal.Decimal).Truncate(precision)
120+
}
121+
122+
lType := t.LeftChild.Type()
123+
if types.IsSigned(lType) {
124+
res, _, err = types.Int64.Convert(ctx, tmp)
125+
} else if types.IsUnsigned(lType) {
126+
res, _, err = types.Uint64.Convert(ctx, tmp)
127+
} else if types.IsFloat(lType) {
128+
res, _, err = types.Float64.Convert(ctx, tmp)
129+
} else if types.IsDecimal(lType) {
130+
res = tmp
131+
} else if types.IsTextBlob(lType) {
132+
res, _, err = types.Float64.Convert(ctx, tmp)
133+
} else {
134+
// Default to Float64 for unknown types
135+
res, _, err = types.Float64.Convert(ctx, tmp)
136+
}
137+
if err != nil && sql.ErrTruncatedIncorrect.Is(err) {
138+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
139+
err = nil
140+
}
141+
return res, err
142+
}
143+
144+
// IsNullable implements the Expression interface.
145+
func (t *Truncate) IsNullable() bool {
146+
return t.LeftChild.IsNullable() || t.RightChild.IsNullable()
147+
}
148+
149+
func (t *Truncate) String() string {
150+
return fmt.Sprintf("%s(%s,%s)", t.FunctionName(), t.LeftChild.String(), t.RightChild.String())
151+
}
152+
153+
// Resolved implements the Expression interface.
154+
func (t *Truncate) Resolved() bool {
155+
return t.LeftChild.Resolved() && t.RightChild.Resolved()
156+
}
157+
158+
// Type implements the Expression interface.
159+
func (t *Truncate) Type() sql.Type {
160+
return numericRetType(t.LeftChild.Type())
161+
}
162+
163+
// CollationCoercibility implements the interface sql.CollationCoercible.
164+
func (*Truncate) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
165+
return sql.Collation_binary, 5
166+
}
167+
168+
// WithChildren implements the Expression interface.
169+
func (t *Truncate) WithChildren(children ...sql.Expression) (sql.Expression, error) {
170+
return NewTruncate(children...)
171+
}

0 commit comments

Comments
 (0)