Skip to content

Commit 2296332

Browse files
committed
First draft of INSERT func, thanks claude
1 parent cf1535b commit 2296332

File tree

3 files changed

+248
-0
lines changed

3 files changed

+248
-0
lines changed

sql/expression/function/insert.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
// Copyright 2020-2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/types"
22+
)
23+
24+
// Insert implements the SQL function INSERT() which inserts a substring at a specified position
25+
type Insert struct {
26+
str sql.Expression
27+
pos sql.Expression
28+
length sql.Expression
29+
newStr sql.Expression
30+
}
31+
32+
var _ sql.FunctionExpression = (*Insert)(nil)
33+
var _ sql.CollationCoercible = (*Insert)(nil)
34+
35+
// NewInsert creates a new Insert expression
36+
func NewInsert(str, pos, length, newStr sql.Expression) sql.Expression {
37+
return &Insert{str, pos, length, newStr}
38+
}
39+
40+
// FunctionName implements sql.FunctionExpression
41+
func (i *Insert) FunctionName() string {
42+
return "insert"
43+
}
44+
45+
// Description implements sql.FunctionExpression
46+
func (i *Insert) Description() string {
47+
return "returns the string str, with the substring beginning at position pos and len characters long replaced by the string newstr."
48+
}
49+
50+
// Children implements the Expression interface
51+
func (i *Insert) Children() []sql.Expression {
52+
return []sql.Expression{i.str, i.pos, i.length, i.newStr}
53+
}
54+
55+
// Resolved implements the Expression interface
56+
func (i *Insert) Resolved() bool {
57+
return i.str.Resolved() && i.pos.Resolved() && i.length.Resolved() && i.newStr.Resolved()
58+
}
59+
60+
// IsNullable implements the Expression interface
61+
func (i *Insert) IsNullable() bool {
62+
return i.str.IsNullable() || i.pos.IsNullable() || i.length.IsNullable() || i.newStr.IsNullable()
63+
}
64+
65+
// Type implements the Expression interface
66+
func (i *Insert) Type() sql.Type {
67+
return types.LongText
68+
}
69+
70+
// CollationCoercibility implements the interface sql.CollationCoercible
71+
func (i *Insert) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
72+
collation, coercibility = sql.GetCoercibility(ctx, i.str)
73+
otherCollation, otherCoercibility := sql.GetCoercibility(ctx, i.newStr)
74+
return sql.ResolveCoercibility(collation, coercibility, otherCollation, otherCoercibility)
75+
}
76+
77+
// String implements the Expression interface
78+
func (i *Insert) String() string {
79+
return fmt.Sprintf("insert(%s, %s, %s, %s)", i.str, i.pos, i.length, i.newStr)
80+
}
81+
82+
// WithChildren implements the Expression interface
83+
func (i *Insert) WithChildren(children ...sql.Expression) (sql.Expression, error) {
84+
if len(children) != 4 {
85+
return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 4)
86+
}
87+
return NewInsert(children[0], children[1], children[2], children[3]), nil
88+
}
89+
90+
// Eval implements the Expression interface
91+
func (i *Insert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
92+
str, err := i.str.Eval(ctx, row)
93+
if err != nil {
94+
return nil, err
95+
}
96+
if str == nil {
97+
return nil, nil
98+
}
99+
100+
pos, err := i.pos.Eval(ctx, row)
101+
if err != nil {
102+
return nil, err
103+
}
104+
if pos == nil {
105+
return nil, nil
106+
}
107+
108+
length, err := i.length.Eval(ctx, row)
109+
if err != nil {
110+
return nil, err
111+
}
112+
if length == nil {
113+
return nil, nil
114+
}
115+
116+
newStr, err := i.newStr.Eval(ctx, row)
117+
if err != nil {
118+
return nil, err
119+
}
120+
if newStr == nil {
121+
return nil, nil
122+
}
123+
124+
// Convert all arguments to their expected types
125+
strVal, _, err := types.LongText.Convert(ctx, str)
126+
if err != nil {
127+
return nil, err
128+
}
129+
130+
posVal, _, err := types.Int64.Convert(ctx, pos)
131+
if err != nil {
132+
return nil, err
133+
}
134+
135+
lengthVal, _, err := types.Int64.Convert(ctx, length)
136+
if err != nil {
137+
return nil, err
138+
}
139+
140+
newStrVal, _, err := types.LongText.Convert(ctx, newStr)
141+
if err != nil {
142+
return nil, err
143+
}
144+
145+
s := strVal.(string)
146+
p := posVal.(int64)
147+
l := lengthVal.(int64)
148+
n := newStrVal.(string)
149+
150+
// MySQL uses 1-based indexing for position
151+
// Handle negative position or negative length
152+
if p < 1 || l < 0 {
153+
return s, nil
154+
}
155+
156+
// Convert to 0-based indexing
157+
startIdx := p - 1
158+
159+
// Handle case where position is beyond string length
160+
if startIdx >= int64(len(s)) {
161+
return s, nil
162+
}
163+
164+
// Calculate end index
165+
endIdx := startIdx + l
166+
if endIdx > int64(len(s)) {
167+
endIdx = int64(len(s))
168+
}
169+
170+
// Build the result string
171+
result := s[:startIdx] + n + s[endIdx:]
172+
return result, nil
173+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2020-2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"testing"
19+
20+
"github.com/stretchr/testify/require"
21+
22+
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/dolthub/go-mysql-server/sql/expression"
24+
"github.com/dolthub/go-mysql-server/sql/types"
25+
)
26+
27+
func TestInsert(t *testing.T) {
28+
f := NewInsert(
29+
expression.NewGetField(0, types.LongText, "", false),
30+
expression.NewGetField(1, types.Int64, "", false),
31+
expression.NewGetField(2, types.Int64, "", false),
32+
expression.NewGetField(3, types.LongText, "", false),
33+
)
34+
35+
testCases := []struct {
36+
name string
37+
row sql.Row
38+
expected interface{}
39+
err bool
40+
}{
41+
{"null str", sql.NewRow(nil, 1, 2, "new"), nil, false},
42+
{"null pos", sql.NewRow("hello", nil, 2, "new"), nil, false},
43+
{"null length", sql.NewRow("hello", 1, nil, "new"), nil, false},
44+
{"null newStr", sql.NewRow("hello", 1, 2, nil), nil, false},
45+
{"empty string", sql.NewRow("", 1, 2, "new"), "", false},
46+
{"position is 0", sql.NewRow("hello", 0, 2, "new"), "hello", false},
47+
{"position is negative", sql.NewRow("hello", -1, 2, "new"), "hello", false},
48+
{"negative length", sql.NewRow("hello", 1, -1, "new"), "hello", false},
49+
{"position beyond string length", sql.NewRow("hello", 10, 2, "new"), "hello", false},
50+
{"normal insertion", sql.NewRow("hello", 2, 2, "xyz"), "hxyzlo", false},
51+
{"insert at beginning", sql.NewRow("hello", 1, 2, "xyz"), "xyzllo", false},
52+
{"insert at end", sql.NewRow("hello", 5, 1, "xyz"), "hellxyz", false},
53+
{"replace entire string", sql.NewRow("hello", 1, 5, "world"), "world", false},
54+
{"length exceeds string", sql.NewRow("hello", 3, 10, "world"), "heworld", false},
55+
{"empty replacement", sql.NewRow("hello", 2, 2, ""), "hlo", false},
56+
{"zero length", sql.NewRow("hello", 3, 0, "xyz"), "hexyzllo", false},
57+
}
58+
59+
for _, tt := range testCases {
60+
t.Run(tt.name, func(t *testing.T) {
61+
t.Helper()
62+
require := require.New(t)
63+
ctx := sql.NewEmptyContext()
64+
65+
v, err := f.Eval(ctx, tt.row)
66+
if tt.err {
67+
require.Error(err)
68+
} else {
69+
require.NoError(err)
70+
require.Equal(tt.expected, v)
71+
}
72+
})
73+
}
74+
}

sql/expression/function/registry.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ var BuiltIns = []sql.Function{
111111
sql.Function1{Name: "inet_ntoa", Fn: NewInetNtoa},
112112
sql.Function1{Name: "inet6_aton", Fn: NewInet6Aton},
113113
sql.Function1{Name: "inet6_ntoa", Fn: NewInet6Ntoa},
114+
sql.Function4{Name: "insert", Fn: NewInsert},
114115
sql.Function2{Name: "instr", Fn: NewInstr},
115116
sql.Function1{Name: "is_binary", Fn: NewIsBinary},
116117
sql.Function1{Name: "is_ipv4", Fn: NewIsIPv4},

0 commit comments

Comments
 (0)