Skip to content
Merged
60 changes: 32 additions & 28 deletions sql/expression/function/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,62 +136,66 @@ func (c *Conv) WithChildren(children ...sql.Expression) (sql.Expression, error)
// This conversion truncates nVal as its first subpart that is convertable.
// nVal is treated as unsigned except nVal is negative.
func convertFromBase(ctx *sql.Context, nVal string, fromBase interface{}) interface{} {
fromBase, _, err := types.Int64.Convert(ctx, fromBase)
if err != nil {
if len(nVal) == 0 {
return nil
}

fromVal := int(math.Abs(float64(fromBase.(int64))))
// Convert and validate fromBase
baseVal, _, err := types.Int64.Convert(ctx, fromBase)
if err != nil {
return nil
}
fromVal := int(math.Abs(float64(baseVal.(int64))))
if fromVal < 2 || fromVal > 36 {
return nil
}

// Handle sign
negative := false
var upper string
var lower string
if nVal[0] == '-' {
switch {
case nVal[0] == '-':
if len(nVal) == 1 {
return uint64(0)
}
negative = true
nVal = nVal[1:]
} else if nVal[0] == '+' {
case nVal[0] == '+':
if len(nVal) == 1 {
return uint64(0)
}
nVal = nVal[1:]
}

// check for upper and lower bound for given fromBase
// Determine bounds based on sign
var maxLen int
if negative {
upper = strconv.FormatInt(math.MaxInt64, fromVal)
lower = strconv.FormatInt(math.MinInt64, fromVal)
if len(nVal) > len(lower) {
nVal = lower
} else if len(nVal) > len(upper) {
nVal = upper
maxLen = len(strconv.FormatInt(math.MinInt64, fromVal))
if len(nVal) > maxLen {
// Use MinInt64 representation in the given base
nVal = strconv.FormatInt(math.MinInt64, fromVal)[1:] // remove minus sign
}
} else {
upper = strconv.FormatUint(math.MaxUint64, fromVal)
lower = "0"
if len(nVal) < len(lower) {
nVal = lower
} else if len(nVal) > len(upper) {
nVal = upper
maxLen = len(strconv.FormatUint(math.MaxUint64, fromVal))
if len(nVal) > maxLen {
// Use MaxUint64 representation in the given base
nVal = strconv.FormatUint(math.MaxUint64, fromVal)
}
}

truncate := false
result := uint64(0)
i := 1
for !truncate && i <= len(nVal) {
// Find the longest valid prefix that can be converted
var result uint64
for i := 1; i <= len(nVal); i++ {
val, err := strconv.ParseUint(nVal[:i], fromVal, 64)
if err != nil {
truncate = true
return result
break
}
result = val
i++
}

if negative {
// MySQL returns signed value for negative inputs
return int64(result) * -1
}

return result
}

Expand Down
2 changes: 2 additions & 0 deletions sql/expression/function/conv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ func TestConv(t *testing.T) {
{"n is nil", types.Int32, sql.NewRow(nil, 16, 2), nil},
{"fromBase is nil", types.LongText, sql.NewRow('a', nil, 2), nil},
{"toBase is nil", types.LongText, sql.NewRow('a', 16, nil), nil},
{"empty n string", types.LongText, sql.NewRow("", 3, 4), nil},
{"empty arg strings", types.LongText, sql.NewRow(4, "", ""), nil},

// invalid inputs
{"invalid N", types.LongText, sql.NewRow("r", 16, 2), "0"},
Expand Down
91 changes: 91 additions & 0 deletions sql/expression/function/oct.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package function

import (
"fmt"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/types"
)

// Oct function provides a string representation for the octal value of N, where N is a decimal (base 10) number.
type Oct struct {
n sql.Expression
}

var _ sql.FunctionExpression = (*Oct)(nil)
var _ sql.CollationCoercible = (*Oct)(nil)

// NewOct returns a new Oct expression.
func NewOct(n sql.Expression) sql.Expression { return &Oct{n} }

// FunctionName implements sql.FunctionExpression.
func (o *Oct) FunctionName() string {
return "oct"
}

// Description implements sql.FunctionExpression.
func (o *Oct) Description() string {
return "returns a string representation for octal value of N, where N is a decimal (base 10) number."
}

// Type implements the Expression interface.
func (o *Oct) Type() sql.Type {
return types.LongText
}

// IsNullable implements the Expression interface.
func (o *Oct) IsNullable() bool {
return o.n.IsNullable()
}

// Eval implements the Expression interface.
func (o *Oct) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
// Convert a decimal (base 10) number to octal (base 8)
return NewConv(
o.n,
expression.NewLiteral(10, types.Int64),
expression.NewLiteral(8, types.Int64),
).Eval(ctx, row)
}

// Resolved implements the Expression interface.
func (o *Oct) Resolved() bool {
return o.n.Resolved()
}

// Children implements the Expression interface.
func (o *Oct) Children() []sql.Expression {
return []sql.Expression{o.n}
}

// WithChildren implements the Expression interface.
func (o *Oct) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 1)
}
return NewOct(children[0]), nil
}

func (o *Oct) String() string {
return fmt.Sprintf("%s(%s)", o.FunctionName(), o.n)
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (*Oct) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return ctx.GetCollation(), 4 // strings with collations
}
80 changes: 80 additions & 0 deletions sql/expression/function/oct_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package function

import (
"math"
"testing"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/types"
)

type test struct {
name string
nType sql.Type
row sql.Row
expected interface{}
}

func TestOct(t *testing.T) {
tests := []test{
// NULL input
{"n is nil", types.Int32, sql.NewRow(nil), nil},

// Positive numbers
{"positive small", types.Int32, sql.NewRow(8), "10"},
{"positive medium", types.Int32, sql.NewRow(64), "100"},
{"positive large", types.Int32, sql.NewRow(4095), "7777"},
{"positive huge", types.Int64, sql.NewRow(123456789), "726746425"},

// Negative numbers
{"negative small", types.Int32, sql.NewRow(-8), "1777777777777777777770"},
{"negative medium", types.Int32, sql.NewRow(-64), "1777777777777777777700"},
{"negative large", types.Int32, sql.NewRow(-4095), "1777777777777777770001"},

// Zero
{"zero", types.Int32, sql.NewRow(0), "0"},

// String inputs
{"string number", types.LongText, sql.NewRow("15"), "17"},
{"alpha string", types.LongText, sql.NewRow("abc"), "0"},
{"mixed string", types.LongText, sql.NewRow("123abc"), "173"},

// Edge cases
{"max int32", types.Int32, sql.NewRow(math.MaxInt32), "17777777777"},
{"min int32", types.Int32, sql.NewRow(math.MinInt32), "1777777777760000000000"},
{"max int64", types.Int64, sql.NewRow(math.MaxInt64), "777777777777777777777"},
{"min int64", types.Int64, sql.NewRow(math.MinInt64), "1000000000000000000000"},

// Decimal numbers
{"decimal", types.Float64, sql.NewRow(15.5), "17"},
{"negative decimal", types.Float64, sql.NewRow(-15.5), "1777777777777777777761"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := NewOct(expression.NewGetField(0, tt.nType, "n", true))
result, err := f.Eval(sql.NewEmptyContext(), tt.row)
if err != nil {
t.Fatal(err)
}
if result != tt.expected {
t.Errorf("got %v; expected %v", result, tt.expected)
}
})
}
}
1 change: 1 addition & 0 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ var BuiltIns = []sql.Function{
sql.Function1{Name: "ntile", Fn: window.NewNTile},
sql.FunctionN{Name: "now", Fn: NewNow},
sql.Function2{Name: "nullif", Fn: NewNullIf},
sql.Function1{Name: "oct", Fn: NewOct},
sql.Function1{Name: "octet_length", Fn: NewLength},
sql.Function1{Name: "ord", Fn: NewOrd},
sql.Function0{Name: "pi", Fn: NewPi},
Expand Down