Skip to content

Commit e58117e

Browse files
authored
Merge pull request #3079 from dolthub/zachmu/sql-funcs
INSERT string function
2 parents b38ec53 + 9d1bcdb commit e58117e

File tree

4 files changed

+402
-0
lines changed

4 files changed

+402
-0
lines changed

enginetest/queries/queries.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5431,6 +5431,126 @@ SELECT * FROM cte WHERE d = 2;`,
54315431
{string("abc")},
54325432
},
54335433
},
5434+
{
5435+
Query: `SELECT INSERT("Quadratic", 3, 4, "What")`,
5436+
Expected: []sql.Row{
5437+
{string("QuWhattic")},
5438+
},
5439+
},
5440+
{
5441+
Query: `SELECT INSERT("hello", 2, 2, "xyz")`,
5442+
Expected: []sql.Row{
5443+
{string("hxyzlo")},
5444+
},
5445+
},
5446+
{
5447+
Query: `SELECT INSERT("hello", 1, 2, "xyz")`,
5448+
Expected: []sql.Row{
5449+
{string("xyzllo")},
5450+
},
5451+
},
5452+
{
5453+
Query: `SELECT INSERT("hello", 5, 1, "xyz")`,
5454+
Expected: []sql.Row{
5455+
{string("hellxyz")},
5456+
},
5457+
},
5458+
{
5459+
Query: `SELECT INSERT("hello", 1, 5, "world")`,
5460+
Expected: []sql.Row{
5461+
{string("world")},
5462+
},
5463+
},
5464+
{
5465+
Query: `SELECT INSERT("hello", 3, 10, "world")`,
5466+
Expected: []sql.Row{
5467+
{string("heworld")},
5468+
},
5469+
},
5470+
{
5471+
Query: `SELECT INSERT("hello", 2, 2, "")`,
5472+
Expected: []sql.Row{
5473+
{string("hlo")},
5474+
},
5475+
},
5476+
{
5477+
Query: `SELECT INSERT("hello", 3, 0, "xyz")`,
5478+
Expected: []sql.Row{
5479+
{string("hexyzllo")},
5480+
},
5481+
},
5482+
{
5483+
Query: `SELECT INSERT("hello", 0, 2, "xyz")`,
5484+
Expected: []sql.Row{
5485+
{string("hello")},
5486+
},
5487+
},
5488+
{
5489+
Query: `SELECT INSERT("hello", -1, 2, "xyz")`,
5490+
Expected: []sql.Row{
5491+
{string("hello")},
5492+
},
5493+
},
5494+
{
5495+
Query: `SELECT INSERT("hello", 1, -1, "xyz")`,
5496+
Expected: []sql.Row{
5497+
{string("xyz")},
5498+
},
5499+
},
5500+
{
5501+
Query: `SELECT INSERT("hello", 3, -1, "xyz")`,
5502+
Expected: []sql.Row{
5503+
{string("hexyz")},
5504+
},
5505+
},
5506+
{
5507+
Query: `SELECT INSERT("hello", 2, 100, "xyz")`,
5508+
Expected: []sql.Row{
5509+
{string("hxyz")},
5510+
},
5511+
},
5512+
{
5513+
Query: `SELECT INSERT("hello", 1, 50, "world")`,
5514+
Expected: []sql.Row{
5515+
{string("world")},
5516+
},
5517+
},
5518+
{
5519+
Query: `SELECT INSERT("hello", 10, 2, "xyz")`,
5520+
Expected: []sql.Row{
5521+
{string("hello")},
5522+
},
5523+
},
5524+
{
5525+
Query: `SELECT INSERT("", 1, 2, "xyz")`,
5526+
Expected: []sql.Row{
5527+
{string("")},
5528+
},
5529+
},
5530+
{
5531+
Query: `SELECT INSERT(NULL, 1, 2, "xyz")`,
5532+
Expected: []sql.Row{
5533+
{nil},
5534+
},
5535+
},
5536+
{
5537+
Query: `SELECT INSERT("hello", NULL, 2, "xyz")`,
5538+
Expected: []sql.Row{
5539+
{nil},
5540+
},
5541+
},
5542+
{
5543+
Query: `SELECT INSERT("hello", 1, NULL, "xyz")`,
5544+
Expected: []sql.Row{
5545+
{nil},
5546+
},
5547+
},
5548+
{
5549+
Query: `SELECT INSERT("hello", 1, 2, NULL)`,
5550+
Expected: []sql.Row{
5551+
{nil},
5552+
},
5553+
},
54345554
{
54355555
Query: `SELECT COALESCE(NULL, NULL, NULL, 'example', NULL, 1234567890)`,
54365556
Expected: []sql.Row{
@@ -5467,6 +5587,30 @@ SELECT * FROM cte WHERE d = 2;`,
54675587
{string("third row3")},
54685588
},
54695589
},
5590+
{
5591+
Query: `SELECT INSERT(s, 1, 5, "new") FROM mytable ORDER BY i`,
5592+
Expected: []sql.Row{
5593+
{string("new row")},
5594+
{string("newd row")},
5595+
{string("new row")},
5596+
},
5597+
},
5598+
{
5599+
Query: `SELECT INSERT(s, i, 2, "XY") FROM mytable ORDER BY i`,
5600+
Expected: []sql.Row{
5601+
{string("XYrst row")},
5602+
{string("sXYond row")},
5603+
{string("thXYd row")},
5604+
},
5605+
},
5606+
{
5607+
Query: `SELECT INSERT(s, i + 1, i, UPPER(s)) FROM mytable ORDER BY i`,
5608+
Expected: []sql.Row{
5609+
{string("fFIRST ROWrst row")},
5610+
{string("seSECOND ROWnd row")},
5611+
{string("thiTHIRD ROWrow")},
5612+
},
5613+
},
54705614
{
54715615
Query: "SELECT version()",
54725616
Expected: []sql.Row{

sql/expression/function/insert.go

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
// Copyright 2020-2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/types"
22+
)
23+
24+
// Insert implements the SQL function INSERT() which inserts a substring at a specified position
25+
type Insert struct {
26+
str sql.Expression
27+
pos sql.Expression
28+
length sql.Expression
29+
newStr sql.Expression
30+
}
31+
32+
var _ sql.FunctionExpression = (*Insert)(nil)
33+
var _ sql.CollationCoercible = (*Insert)(nil)
34+
35+
// NewInsert creates a new Insert expression
36+
func NewInsert(str, pos, length, newStr sql.Expression) sql.Expression {
37+
return &Insert{str, pos, length, newStr}
38+
}
39+
40+
// FunctionName implements sql.FunctionExpression
41+
func (i *Insert) FunctionName() string {
42+
return "insert"
43+
}
44+
45+
// Description implements sql.FunctionExpression
46+
func (i *Insert) Description() string {
47+
return "returns the string str, with the substring beginning at position pos and len characters long replaced by the string newstr."
48+
}
49+
50+
// Children implements the Expression interface
51+
func (i *Insert) Children() []sql.Expression {
52+
return []sql.Expression{i.str, i.pos, i.length, i.newStr}
53+
}
54+
55+
// Resolved implements the Expression interface
56+
func (i *Insert) Resolved() bool {
57+
return i.str.Resolved() && i.pos.Resolved() && i.length.Resolved() && i.newStr.Resolved()
58+
}
59+
60+
// IsNullable implements the Expression interface
61+
func (i *Insert) IsNullable() bool {
62+
return i.str.IsNullable() || i.pos.IsNullable() || i.length.IsNullable() || i.newStr.IsNullable()
63+
}
64+
65+
// Type implements the Expression interface
66+
func (i *Insert) Type() sql.Type {
67+
return types.LongText
68+
}
69+
70+
// CollationCoercibility implements the interface sql.CollationCoercible
71+
func (i *Insert) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
72+
collation, coercibility = sql.GetCoercibility(ctx, i.str)
73+
otherCollation, otherCoercibility := sql.GetCoercibility(ctx, i.newStr)
74+
return sql.ResolveCoercibility(collation, coercibility, otherCollation, otherCoercibility)
75+
}
76+
77+
// String implements the Expression interface
78+
func (i *Insert) String() string {
79+
return fmt.Sprintf("insert(%s, %s, %s, %s)", i.str, i.pos, i.length, i.newStr)
80+
}
81+
82+
// WithChildren implements the Expression interface
83+
func (i *Insert) WithChildren(children ...sql.Expression) (sql.Expression, error) {
84+
if len(children) != 4 {
85+
return nil, sql.ErrInvalidChildrenNumber.New(i, len(children), 4)
86+
}
87+
return NewInsert(children[0], children[1], children[2], children[3]), nil
88+
}
89+
90+
// Eval implements the Expression interface
91+
func (i *Insert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
92+
str, err := i.str.Eval(ctx, row)
93+
if err != nil {
94+
return nil, err
95+
}
96+
if str == nil {
97+
return nil, nil
98+
}
99+
100+
pos, err := i.pos.Eval(ctx, row)
101+
if err != nil {
102+
return nil, err
103+
}
104+
if pos == nil {
105+
return nil, nil
106+
}
107+
108+
length, err := i.length.Eval(ctx, row)
109+
if err != nil {
110+
return nil, err
111+
}
112+
if length == nil {
113+
return nil, nil
114+
}
115+
116+
newStr, err := i.newStr.Eval(ctx, row)
117+
if err != nil {
118+
return nil, err
119+
}
120+
if newStr == nil {
121+
return nil, nil
122+
}
123+
124+
// Convert all arguments to their expected types
125+
strVal, _, err := types.LongText.Convert(ctx, str)
126+
if err != nil {
127+
return nil, err
128+
}
129+
130+
posVal, _, err := types.Int64.Convert(ctx, pos)
131+
if err != nil {
132+
return nil, err
133+
}
134+
135+
lengthVal, _, err := types.Int64.Convert(ctx, length)
136+
if err != nil {
137+
return nil, err
138+
}
139+
140+
newStrVal, _, err := types.LongText.Convert(ctx, newStr)
141+
if err != nil {
142+
return nil, err
143+
}
144+
145+
s := strVal.(string)
146+
p := posVal.(int64)
147+
l := lengthVal.(int64)
148+
n := newStrVal.(string)
149+
150+
// MySQL uses 1-based indexing for position
151+
// Handle negative position - return original string
152+
if p < 1 {
153+
return s, nil
154+
}
155+
156+
// Convert to 0-based indexing
157+
startIdx := p - 1
158+
159+
// Handle case where position is beyond string length
160+
if startIdx >= int64(len(s)) {
161+
return s, nil
162+
}
163+
164+
// Calculate end index
165+
// For negative length, replace from position to end of string
166+
var endIdx int64
167+
if l < 0 {
168+
endIdx = int64(len(s))
169+
} else {
170+
endIdx = startIdx + l
171+
if endIdx > int64(len(s)) {
172+
endIdx = int64(len(s))
173+
}
174+
}
175+
176+
// Build the result string
177+
result := s[:startIdx] + n + s[endIdx:]
178+
return result, nil
179+
}

0 commit comments

Comments
 (0)