diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 991ee3577e..58aa9d3d91 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -5390,6 +5390,126 @@ SELECT * FROM cte WHERE d = 2;`, {string("abc")}, }, }, + { + Query: `SELECT INSERT("Quadratic", 3, 4, "What")`, + Expected: []sql.Row{ + {string("QuWhattic")}, + }, + }, + { + Query: `SELECT INSERT("hello", 2, 2, "xyz")`, + Expected: []sql.Row{ + {string("hxyzlo")}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, 2, "xyz")`, + Expected: []sql.Row{ + {string("xyzllo")}, + }, + }, + { + Query: `SELECT INSERT("hello", 5, 1, "xyz")`, + Expected: []sql.Row{ + {string("hellxyz")}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, 5, "world")`, + Expected: []sql.Row{ + {string("world")}, + }, + }, + { + Query: `SELECT INSERT("hello", 3, 10, "world")`, + Expected: []sql.Row{ + {string("heworld")}, + }, + }, + { + Query: `SELECT INSERT("hello", 2, 2, "")`, + Expected: []sql.Row{ + {string("hlo")}, + }, + }, + { + Query: `SELECT INSERT("hello", 3, 0, "xyz")`, + Expected: []sql.Row{ + {string("hexyzllo")}, + }, + }, + { + Query: `SELECT INSERT("hello", 0, 2, "xyz")`, + Expected: []sql.Row{ + {string("hello")}, + }, + }, + { + Query: `SELECT INSERT("hello", -1, 2, "xyz")`, + Expected: []sql.Row{ + {string("hello")}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, -1, "xyz")`, + Expected: []sql.Row{ + {string("xyz")}, + }, + }, + { + Query: `SELECT INSERT("hello", 3, -1, "xyz")`, + Expected: []sql.Row{ + {string("hexyz")}, + }, + }, + { + Query: `SELECT INSERT("hello", 2, 100, "xyz")`, + Expected: []sql.Row{ + {string("hxyz")}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, 50, "world")`, + Expected: []sql.Row{ + {string("world")}, + }, + }, + { + Query: `SELECT INSERT("hello", 10, 2, "xyz")`, + Expected: []sql.Row{ + {string("hello")}, + }, + }, + { + Query: `SELECT INSERT("", 1, 2, "xyz")`, + Expected: []sql.Row{ + {string("")}, + }, + }, + { + Query: `SELECT INSERT(NULL, 1, 2, "xyz")`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT INSERT("hello", NULL, 2, "xyz")`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, NULL, "xyz")`, + Expected: []sql.Row{ + {nil}, + }, + }, + { + Query: `SELECT INSERT("hello", 1, 2, NULL)`, + Expected: []sql.Row{ + {nil}, + }, + }, { Query: `SELECT COALESCE(NULL, NULL, NULL, 'example', NULL, 1234567890)`, Expected: []sql.Row{ @@ -5426,6 +5546,30 @@ SELECT * FROM cte WHERE d = 2;`, {string("third row3")}, }, }, + { + Query: `SELECT INSERT(s, 1, 5, "new") FROM mytable ORDER BY i`, + Expected: []sql.Row{ + {string("new row")}, + {string("newd row")}, + {string("new row")}, + }, + }, + { + Query: `SELECT INSERT(s, i, 2, "XY") FROM mytable ORDER BY i`, + Expected: []sql.Row{ + {string("XYrst row")}, + {string("sXYond row")}, + {string("thXYd row")}, + }, + }, + { + Query: `SELECT INSERT(s, i + 1, i, UPPER(s)) FROM mytable ORDER BY i`, + Expected: []sql.Row{ + {string("fFIRST ROWrst row")}, + {string("seSECOND ROWnd row")}, + {string("thiTHIRD ROWrow")}, + }, + }, { Query: "SELECT version()", Expected: []sql.Row{ diff --git a/sql/expression/function/insert.go b/sql/expression/function/insert.go new file mode 100644 index 0000000000..55029521bc --- /dev/null +++ b/sql/expression/function/insert.go @@ -0,0 +1,179 @@ +// Copyright 2020-2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" +) + +// Insert implements the SQL function INSERT() which inserts a substring at a specified position +type Insert struct { + str sql.Expression + pos sql.Expression + length sql.Expression + newStr sql.Expression +} + +var _ sql.FunctionExpression = (*Insert)(nil) +var _ sql.CollationCoercible = (*Insert)(nil) + +// NewInsert creates a new Insert expression +func NewInsert(str, pos, length, newStr sql.Expression) sql.Expression { + return &Insert{str, pos, length, newStr} +} + +// FunctionName implements sql.FunctionExpression +func (i *Insert) FunctionName() string { + return "insert" +} + +// Description implements sql.FunctionExpression +func (i *Insert) Description() string { + return "returns the string str, with the substring beginning at position pos and len characters long replaced by the string newstr." +} + +// Children implements the Expression interface +func (i *Insert) Children() []sql.Expression { + return []sql.Expression{i.str, i.pos, i.length, i.newStr} +} + +// Resolved implements the Expression interface +func (i *Insert) Resolved() bool { + return i.str.Resolved() && i.pos.Resolved() && i.length.Resolved() && i.newStr.Resolved() +} + +// IsNullable implements the Expression interface +func (i *Insert) IsNullable() bool { + return i.str.IsNullable() || i.pos.IsNullable() || i.length.IsNullable() || i.newStr.IsNullable() +} + +// Type implements the Expression interface +func (i *Insert) Type() sql.Type { + return types.LongText +} + +// CollationCoercibility implements the interface sql.CollationCoercible +func (i *Insert) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + collation, coercibility = sql.GetCoercibility(ctx, i.str) + otherCollation, otherCoercibility := sql.GetCoercibility(ctx, i.newStr) + return sql.ResolveCoercibility(collation, coercibility, otherCollation, otherCoercibility) +} + +// String implements the Expression interface +func (i *Insert) String() string { + return fmt.Sprintf("insert(%s, %s, %s, %s)", i.str, i.pos, i.length, i.newStr) +} + +// WithChildren implements the Expression interface +func (i *Insert) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 4 { + return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 4) + } + return NewInsert(children[0], children[1], children[2], children[3]), nil +} + +// Eval implements the Expression interface +func (i *Insert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + str, err := i.str.Eval(ctx, row) + if err != nil { + return nil, err + } + if str == nil { + return nil, nil + } + + pos, err := i.pos.Eval(ctx, row) + if err != nil { + return nil, err + } + if pos == nil { + return nil, nil + } + + length, err := i.length.Eval(ctx, row) + if err != nil { + return nil, err + } + if length == nil { + return nil, nil + } + + newStr, err := i.newStr.Eval(ctx, row) + if err != nil { + return nil, err + } + if newStr == nil { + return nil, nil + } + + // Convert all arguments to their expected types + strVal, _, err := types.LongText.Convert(ctx, str) + if err != nil { + return nil, err + } + + posVal, _, err := types.Int64.Convert(ctx, pos) + if err != nil { + return nil, err + } + + lengthVal, _, err := types.Int64.Convert(ctx, length) + if err != nil { + return nil, err + } + + newStrVal, _, err := types.LongText.Convert(ctx, newStr) + if err != nil { + return nil, err + } + + s := strVal.(string) + p := posVal.(int64) + l := lengthVal.(int64) + n := newStrVal.(string) + + // MySQL uses 1-based indexing for position + // Handle negative position - return original string + if p < 1 { + return s, nil + } + + // Convert to 0-based indexing + startIdx := p - 1 + + // Handle case where position is beyond string length + if startIdx >= int64(len(s)) { + return s, nil + } + + // Calculate end index + // For negative length, replace from position to end of string + var endIdx int64 + if l < 0 { + endIdx = int64(len(s)) + } else { + endIdx = startIdx + l + if endIdx > int64(len(s)) { + endIdx = int64(len(s)) + } + } + + // Build the result string + result := s[:startIdx] + n + s[endIdx:] + return result, nil +} diff --git a/sql/expression/function/insert_test.go b/sql/expression/function/insert_test.go new file mode 100644 index 0000000000..8db924ef32 --- /dev/null +++ b/sql/expression/function/insert_test.go @@ -0,0 +1,78 @@ +// Copyright 2020-2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" +) + +func TestInsert(t *testing.T) { + f := NewInsert( + expression.NewGetField(0, types.LongText, "", false), + expression.NewGetField(1, types.Int64, "", false), + expression.NewGetField(2, types.Int64, "", false), + expression.NewGetField(3, types.LongText, "", false), + ) + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + {"null str", sql.NewRow(nil, 1, 2, "new"), nil, false}, + {"null pos", sql.NewRow("hello", nil, 2, "new"), nil, false}, + {"null length", sql.NewRow("hello", 1, nil, "new"), nil, false}, + {"null newStr", sql.NewRow("hello", 1, 2, nil), nil, false}, + {"empty string", sql.NewRow("", 1, 2, "new"), "", false}, + {"position is 0", sql.NewRow("hello", 0, 2, "new"), "hello", false}, + {"position is negative", sql.NewRow("hello", -1, 2, "new"), "hello", false}, + {"negative length", sql.NewRow("hello", 1, -1, "new"), "new", false}, + {"position beyond string length", sql.NewRow("hello", 10, 2, "new"), "hello", false}, + {"normal insertion", sql.NewRow("hello", 2, 2, "xyz"), "hxyzlo", false}, + {"insert at beginning", sql.NewRow("hello", 1, 2, "xyz"), "xyzllo", false}, + {"insert at end", sql.NewRow("hello", 5, 1, "xyz"), "hellxyz", false}, + {"replace entire string", sql.NewRow("hello", 1, 5, "world"), "world", false}, + {"length exceeds string", sql.NewRow("hello", 3, 10, "world"), "heworld", false}, + {"empty replacement", sql.NewRow("hello", 2, 2, ""), "hlo", false}, + {"zero length", sql.NewRow("hello", 3, 0, "xyz"), "hexyzllo", false}, + {"negative length from middle", sql.NewRow("hello", 3, -1, "xyz"), "hexyz", false}, + {"negative length from beginning", sql.NewRow("hello", 1, -5, "xyz"), "xyz", false}, + {"large positive length", sql.NewRow("hello", 2, 100, "xyz"), "hxyz", false}, + {"length exactly matches remaining", sql.NewRow("hello", 3, 3, "xyz"), "hexyz", false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + + v, err := f.Eval(ctx, tt.row) + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + } + }) + } +} diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 996e855afc..3030b5f9be 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -111,6 +111,7 @@ var BuiltIns = []sql.Function{ sql.Function1{Name: "inet_ntoa", Fn: NewInetNtoa}, sql.Function1{Name: "inet6_aton", Fn: NewInet6Aton}, sql.Function1{Name: "inet6_ntoa", Fn: NewInet6Ntoa}, + sql.Function4{Name: "insert", Fn: NewInsert}, sql.Function2{Name: "instr", Fn: NewInstr}, sql.Function1{Name: "is_binary", Fn: NewIsBinary}, sql.Function1{Name: "is_ipv4", Fn: NewIsIPv4},