Skip to content

Commit 9688af8

Browse files
authored
Merge pull request #3020 from dolthub/elianddb/9325-support-oct-function
Support OCT() function and fix CONV() mishandling of negative floats and empty string N
2 parents d38e0af + dba0294 commit 9688af8

File tree

6 files changed

+278
-28
lines changed

6 files changed

+278
-28
lines changed

enginetest/queries/queries.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8387,6 +8387,78 @@ SELECT * FROM cte WHERE d = 2;`,
83878387
Query: "SELECT CONV(i, 10, 2) FROM mytable",
83888388
Expected: []sql.Row{{"1"}, {"10"}, {"11"}},
83898389
},
8390+
{
8391+
Query: "SELECT OCT(8)",
8392+
Expected: []sql.Row{{"10"}},
8393+
},
8394+
{
8395+
Query: "SELECT OCT(255)",
8396+
Expected: []sql.Row{{"377"}},
8397+
},
8398+
{
8399+
Query: "SELECT OCT(0)",
8400+
Expected: []sql.Row{{"0"}},
8401+
},
8402+
{
8403+
Query: "SELECT OCT(1)",
8404+
Expected: []sql.Row{{"1"}},
8405+
},
8406+
{
8407+
Query: "SELECT OCT(NULL)",
8408+
Expected: []sql.Row{{nil}},
8409+
},
8410+
{
8411+
Query: "SELECT OCT(-1)",
8412+
Expected: []sql.Row{{"1777777777777777777777"}},
8413+
},
8414+
{
8415+
Query: "SELECT OCT(-8)",
8416+
Expected: []sql.Row{{"1777777777777777777770"}},
8417+
},
8418+
{
8419+
Query: "SELECT OCT(OCT(4))",
8420+
Expected: []sql.Row{{"4"}},
8421+
},
8422+
{
8423+
Query: "SELECT OCT('16')",
8424+
Expected: []sql.Row{{"20"}},
8425+
},
8426+
{
8427+
Query: "SELECT OCT('abc')",
8428+
Expected: []sql.Row{{"0"}},
8429+
},
8430+
{
8431+
Query: "SELECT OCT(15.7)",
8432+
Expected: []sql.Row{{"17"}},
8433+
},
8434+
{
8435+
Query: "SELECT OCT(-15.2)",
8436+
Expected: []sql.Row{{"1777777777777777777761"}},
8437+
},
8438+
{
8439+
Query: "SELECT OCT(HEX(SUBSTRING('127.0', 1, 3)))",
8440+
Expected: []sql.Row{{"1143625"}},
8441+
},
8442+
{
8443+
Query: "SELECT i, OCT(i), OCT(-i), OCT(i * 2) FROM mytable ORDER BY i",
8444+
Expected: []sql.Row{
8445+
{1, "1", "1777777777777777777777", "2"},
8446+
{2, "2", "1777777777777777777776", "4"},
8447+
{3, "3", "1777777777777777777775", "6"},
8448+
},
8449+
},
8450+
{
8451+
Query: "SELECT OCT(i) FROM mytable ORDER BY CONV(i, 10, 16)",
8452+
Expected: []sql.Row{{"1"}, {"2"}, {"3"}},
8453+
},
8454+
{
8455+
Query: "SELECT i FROM mytable WHERE OCT(s) > 0",
8456+
Expected: []sql.Row{},
8457+
},
8458+
{
8459+
Query: "SELECT s FROM mytable WHERE OCT(i*123) < 400",
8460+
Expected: []sql.Row{{"first row"}, {"second row"}},
8461+
},
83908462
{
83918463
Query: `SELECT t1.pk from one_pk join (one_pk t1 join one_pk t2 on t1.pk = t2.pk) on t1.pk = one_pk.pk and one_pk.pk = 1 join (one_pk t3 join one_pk t4 on t3.c1 is not null) on t3.pk = one_pk.pk and one_pk.c1 = 10`,
83928464
Expected: []sql.Row{{1}, {1}, {1}, {1}},

sql/expression/function/conv.go

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -136,62 +136,66 @@ func (c *Conv) WithChildren(children ...sql.Expression) (sql.Expression, error)
136136
// This conversion truncates nVal as its first subpart that is convertable.
137137
// nVal is treated as unsigned except nVal is negative.
138138
func convertFromBase(ctx *sql.Context, nVal string, fromBase interface{}) interface{} {
139-
fromBase, _, err := types.Int64.Convert(ctx, fromBase)
140-
if err != nil {
139+
if len(nVal) == 0 {
141140
return nil
142141
}
143142

144-
fromVal := int(math.Abs(float64(fromBase.(int64))))
143+
// Convert and validate fromBase
144+
baseVal, _, err := types.Int64.Convert(ctx, fromBase)
145+
if err != nil {
146+
return nil
147+
}
148+
fromVal := int(math.Abs(float64(baseVal.(int64))))
145149
if fromVal < 2 || fromVal > 36 {
146150
return nil
147151
}
148152

153+
// Handle sign
149154
negative := false
150-
var upper string
151-
var lower string
152-
if nVal[0] == '-' {
155+
switch nVal[0] {
156+
case '-':
157+
if len(nVal) == 1 {
158+
return uint64(0)
159+
}
153160
negative = true
154161
nVal = nVal[1:]
155-
} else if nVal[0] == '+' {
162+
case '+':
163+
if len(nVal) == 1 {
164+
return uint64(0)
165+
}
156166
nVal = nVal[1:]
157167
}
158168

159-
// check for upper and lower bound for given fromBase
169+
// Determine bounds based on sign
170+
var maxLen int
160171
if negative {
161-
upper = strconv.FormatInt(math.MaxInt64, fromVal)
162-
lower = strconv.FormatInt(math.MinInt64, fromVal)
163-
if len(nVal) > len(lower) {
164-
nVal = lower
165-
} else if len(nVal) > len(upper) {
166-
nVal = upper
172+
maxLen = len(strconv.FormatInt(math.MinInt64, fromVal))
173+
if len(nVal) > maxLen {
174+
// Use MinInt64 representation in the given base
175+
nVal = strconv.FormatInt(math.MinInt64, fromVal)[1:] // remove minus sign
167176
}
168177
} else {
169-
upper = strconv.FormatUint(math.MaxUint64, fromVal)
170-
lower = "0"
171-
if len(nVal) < len(lower) {
172-
nVal = lower
173-
} else if len(nVal) > len(upper) {
174-
nVal = upper
178+
maxLen = len(strconv.FormatUint(math.MaxUint64, fromVal))
179+
if len(nVal) > maxLen {
180+
// Use MaxUint64 representation in the given base
181+
nVal = strconv.FormatUint(math.MaxUint64, fromVal)
175182
}
176183
}
177184

178-
truncate := false
179-
result := uint64(0)
180-
i := 1
181-
for !truncate && i <= len(nVal) {
185+
// Find the longest valid prefix that can be converted
186+
var result uint64
187+
for i := 1; i <= len(nVal); i++ {
182188
val, err := strconv.ParseUint(nVal[:i], fromVal, 64)
183189
if err != nil {
184-
truncate = true
185-
return result
190+
break
186191
}
187192
result = val
188-
i++
189193
}
190194

191195
if negative {
196+
// MySQL returns signed value for negative inputs
192197
return int64(result) * -1
193198
}
194-
195199
return result
196200
}
197201

sql/expression/function/conv_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ func TestConv(t *testing.T) {
3535
{"n is nil", types.Int32, sql.NewRow(nil, 16, 2), nil},
3636
{"fromBase is nil", types.LongText, sql.NewRow('a', nil, 2), nil},
3737
{"toBase is nil", types.LongText, sql.NewRow('a', 16, nil), nil},
38+
{"empty n string", types.LongText, sql.NewRow("", 3, 4), nil},
39+
{"empty arg strings", types.LongText, sql.NewRow(4, "", ""), nil},
3840

3941
// invalid inputs
4042
{"invalid N", types.LongText, sql.NewRow("r", 16, 2), "0"},

sql/expression/function/oct.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/expression"
22+
"github.com/dolthub/go-mysql-server/sql/types"
23+
)
24+
25+
// Oct function provides a string representation for the octal value of N, where N is a decimal (base 10) number.
26+
type Oct struct {
27+
n sql.Expression
28+
}
29+
30+
var _ sql.FunctionExpression = (*Oct)(nil)
31+
var _ sql.CollationCoercible = (*Oct)(nil)
32+
33+
// NewOct returns a new Oct expression.
34+
func NewOct(n sql.Expression) sql.Expression { return &Oct{n} }
35+
36+
// FunctionName implements sql.FunctionExpression.
37+
func (o *Oct) FunctionName() string {
38+
return "oct"
39+
}
40+
41+
// Description implements sql.FunctionExpression.
42+
func (o *Oct) Description() string {
43+
return "returns a string representation for octal value of N, where N is a decimal (base 10) number."
44+
}
45+
46+
// Type implements the Expression interface.
47+
func (o *Oct) Type() sql.Type {
48+
return types.LongText
49+
}
50+
51+
// IsNullable implements the Expression interface.
52+
func (o *Oct) IsNullable() bool {
53+
return o.n.IsNullable()
54+
}
55+
56+
// Eval implements the Expression interface.
57+
func (o *Oct) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
58+
// Convert a decimal (base 10) number to octal (base 8)
59+
return NewConv(
60+
o.n,
61+
expression.NewLiteral(10, types.Int64),
62+
expression.NewLiteral(8, types.Int64),
63+
).Eval(ctx, row)
64+
}
65+
66+
// Resolved implements the Expression interface.
67+
func (o *Oct) Resolved() bool {
68+
return o.n.Resolved()
69+
}
70+
71+
// Children implements the Expression interface.
72+
func (o *Oct) Children() []sql.Expression {
73+
return []sql.Expression{o.n}
74+
}
75+
76+
// WithChildren implements the Expression interface.
77+
func (o *Oct) WithChildren(children ...sql.Expression) (sql.Expression, error) {
78+
if len(children) != 1 {
79+
return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 1)
80+
}
81+
return NewOct(children[0]), nil
82+
}
83+
84+
func (o *Oct) String() string {
85+
return fmt.Sprintf("%s(%s)", o.FunctionName(), o.n)
86+
}
87+
88+
// CollationCoercibility implements the interface sql.CollationCoercible.
89+
func (*Oct) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
90+
return ctx.GetCollation(), 4 // strings with collations
91+
}

sql/expression/function/oct_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"math"
19+
"testing"
20+
21+
"github.com/dolthub/go-mysql-server/sql"
22+
"github.com/dolthub/go-mysql-server/sql/expression"
23+
"github.com/dolthub/go-mysql-server/sql/types"
24+
)
25+
26+
type test struct {
27+
name string
28+
nType sql.Type
29+
row sql.Row
30+
expected interface{}
31+
}
32+
33+
func TestOct(t *testing.T) {
34+
tests := []test{
35+
// NULL input
36+
{"n is nil", types.Int32, sql.NewRow(nil), nil},
37+
38+
// Positive numbers
39+
{"positive small", types.Int32, sql.NewRow(8), "10"},
40+
{"positive medium", types.Int32, sql.NewRow(64), "100"},
41+
{"positive large", types.Int32, sql.NewRow(4095), "7777"},
42+
{"positive huge", types.Int64, sql.NewRow(123456789), "726746425"},
43+
44+
// Negative numbers
45+
{"negative small", types.Int32, sql.NewRow(-8), "1777777777777777777770"},
46+
{"negative medium", types.Int32, sql.NewRow(-64), "1777777777777777777700"},
47+
{"negative large", types.Int32, sql.NewRow(-4095), "1777777777777777770001"},
48+
49+
// Zero
50+
{"zero", types.Int32, sql.NewRow(0), "0"},
51+
52+
// String inputs
53+
{"string number", types.LongText, sql.NewRow("15"), "17"},
54+
{"alpha string", types.LongText, sql.NewRow("abc"), "0"},
55+
{"mixed string", types.LongText, sql.NewRow("123abc"), "173"},
56+
57+
// Edge cases
58+
{"max int32", types.Int32, sql.NewRow(math.MaxInt32), "17777777777"},
59+
{"min int32", types.Int32, sql.NewRow(math.MinInt32), "1777777777760000000000"},
60+
{"max int64", types.Int64, sql.NewRow(math.MaxInt64), "777777777777777777777"},
61+
{"min int64", types.Int64, sql.NewRow(math.MinInt64), "1000000000000000000000"},
62+
63+
// Decimal numbers
64+
{"decimal", types.Float64, sql.NewRow(15.5), "17"},
65+
{"negative decimal", types.Float64, sql.NewRow(-15.5), "1777777777777777777761"},
66+
}
67+
68+
for _, tt := range tests {
69+
t.Run(tt.name, func(t *testing.T) {
70+
f := NewOct(expression.NewGetField(0, tt.nType, "n", true))
71+
result, err := f.Eval(sql.NewEmptyContext(), tt.row)
72+
if err != nil {
73+
t.Fatal(err)
74+
}
75+
if result != tt.expected {
76+
t.Errorf("got %v; expected %v", result, tt.expected)
77+
}
78+
})
79+
}
80+
}

sql/expression/function/registry.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ var BuiltIns = []sql.Function{
184184
sql.Function1{Name: "ntile", Fn: window.NewNTile},
185185
sql.FunctionN{Name: "now", Fn: NewNow},
186186
sql.Function2{Name: "nullif", Fn: NewNullIf},
187+
sql.Function1{Name: "oct", Fn: NewOct},
187188
sql.Function1{Name: "octet_length", Fn: NewLength},
188189
sql.Function1{Name: "ord", Fn: NewOrd},
189190
sql.Function0{Name: "pi", Fn: NewPi},

0 commit comments

Comments
 (0)