Skip to content

Commit 0ea893b

Browse files
committed
New string func
1 parent 9d1bcdb commit 0ea893b

File tree

4 files changed

+508
-0
lines changed

4 files changed

+508
-0
lines changed

enginetest/queries/queries.go

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5570,6 +5570,134 @@ SELECT * FROM cte WHERE d = 2;`,
55705570
{string("thiTHIRD ROWrow")},
55715571
},
55725572
},
5573+
{
5574+
Query: `SELECT EXPORT_SET(5, "Y", "N", ",", 4)`,
5575+
Expected: []sql.Row{
5576+
{string("Y,N,Y,N")},
5577+
},
5578+
},
5579+
{
5580+
Query: `SELECT EXPORT_SET(6, "1", "0", ",", 10)`,
5581+
Expected: []sql.Row{
5582+
{string("0,1,1,0,0,0,0,0,0,0")},
5583+
},
5584+
},
5585+
{
5586+
Query: `SELECT EXPORT_SET(0, "1", "0", ",", 4)`,
5587+
Expected: []sql.Row{
5588+
{string("0,0,0,0")},
5589+
},
5590+
},
5591+
{
5592+
Query: `SELECT EXPORT_SET(15, "1", "0", ",", 4)`,
5593+
Expected: []sql.Row{
5594+
{string("1,1,1,1")},
5595+
},
5596+
},
5597+
{
5598+
Query: `SELECT EXPORT_SET(1, "T", "F", ",", 3)`,
5599+
Expected: []sql.Row{
5600+
{string("T,F,F")},
5601+
},
5602+
},
5603+
{
5604+
Query: `SELECT EXPORT_SET(5, "1", "0", "|", 4)`,
5605+
Expected: []sql.Row{
5606+
{string("1|0|1|0")},
5607+
},
5608+
},
5609+
{
5610+
Query: `SELECT EXPORT_SET(5, "1", "0", "", 4)`,
5611+
Expected: []sql.Row{
5612+
{string("1010")},
5613+
},
5614+
},
5615+
{
5616+
Query: `SELECT EXPORT_SET(5, "ON", "OFF", ",", 4)`,
5617+
Expected: []sql.Row{
5618+
{string("ON,OFF,ON,OFF")},
5619+
},
5620+
},
5621+
{
5622+
Query: `SELECT EXPORT_SET(255, "1", "0", ",", 8)`,
5623+
Expected: []sql.Row{
5624+
{string("1,1,1,1,1,1,1,1")},
5625+
},
5626+
},
5627+
{
5628+
Query: `SELECT EXPORT_SET(1024, "1", "0", ",", 12)`,
5629+
Expected: []sql.Row{
5630+
{string("0,0,0,0,0,0,0,0,0,0,1,0")},
5631+
},
5632+
},
5633+
{
5634+
Query: `SELECT EXPORT_SET(5, "1", "0")`,
5635+
Expected: []sql.Row{
5636+
{string("1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0")},
5637+
},
5638+
},
5639+
{
5640+
Query: `SELECT EXPORT_SET(5, "1", "0", ",", 1)`,
5641+
Expected: []sql.Row{
5642+
{string("1")},
5643+
},
5644+
},
5645+
{
5646+
Query: `SELECT EXPORT_SET(-1, "1", "0", ",", 4)`,
5647+
Expected: []sql.Row{
5648+
{string("1,1,1,1")},
5649+
},
5650+
},
5651+
{
5652+
Query: `SELECT EXPORT_SET(NULL, "1", "0", ",", 4)`,
5653+
Expected: []sql.Row{
5654+
{nil},
5655+
},
5656+
},
5657+
{
5658+
Query: `SELECT EXPORT_SET(5, NULL, "0", ",", 4)`,
5659+
Expected: []sql.Row{
5660+
{nil},
5661+
},
5662+
},
5663+
{
5664+
Query: `SELECT EXPORT_SET(5, "1", NULL, ",", 4)`,
5665+
Expected: []sql.Row{
5666+
{nil},
5667+
},
5668+
},
5669+
{
5670+
Query: `SELECT EXPORT_SET(5, "1", "0", NULL, 4)`,
5671+
Expected: []sql.Row{
5672+
{nil},
5673+
},
5674+
},
5675+
{
5676+
Query: `SELECT EXPORT_SET(5, "1", "0", ",", NULL)`,
5677+
Expected: []sql.Row{
5678+
{nil},
5679+
},
5680+
},
5681+
{
5682+
Query: `SELECT EXPORT_SET("5", "1", "0", ",", 4)`,
5683+
Expected: []sql.Row{
5684+
{string("1,0,1,0")},
5685+
},
5686+
},
5687+
{
5688+
Query: `SELECT EXPORT_SET(5.7, "1", "0", ",", 4)`,
5689+
Expected: []sql.Row{
5690+
{string("0,1,1,0")},
5691+
},
5692+
},
5693+
{
5694+
Query: `SELECT EXPORT_SET(i, "1", "0", ",", 4) FROM mytable ORDER BY i`,
5695+
Expected: []sql.Row{
5696+
{string("1,0,0,0")},
5697+
{string("0,1,0,0")},
5698+
{string("1,1,0,0")},
5699+
},
5700+
},
55735701
{
55745702
Query: "SELECT version()",
55755703
Expected: []sql.Row{

sql/expression/function/export_set.go

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
// Copyright 2020-2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package function
16+
17+
import (
18+
"fmt"
19+
"strings"
20+
21+
"github.com/dolthub/go-mysql-server/sql"
22+
"github.com/dolthub/go-mysql-server/sql/types"
23+
)
24+
25+
// ExportSet implements the SQL function EXPORT_SET() which returns a string representation of bits in a number
26+
type ExportSet struct {
27+
bits sql.Expression
28+
on sql.Expression
29+
off sql.Expression
30+
separator sql.Expression
31+
numberOfBits sql.Expression
32+
}
33+
34+
var _ sql.FunctionExpression = (*ExportSet)(nil)
35+
var _ sql.CollationCoercible = (*ExportSet)(nil)
36+
37+
// NewExportSet creates a new ExportSet expression
38+
func NewExportSet(args ...sql.Expression) (sql.Expression, error) {
39+
if len(args) < 3 || len(args) > 5 {
40+
return nil, sql.ErrInvalidArgumentNumber.New("EXPORT_SET", "3, 4, or 5", len(args))
41+
}
42+
43+
var separator, numberOfBits sql.Expression
44+
if len(args) >= 4 {
45+
separator = args[3]
46+
}
47+
if len(args) == 5 {
48+
numberOfBits = args[4]
49+
}
50+
51+
return &ExportSet{
52+
bits: args[0],
53+
on: args[1],
54+
off: args[2],
55+
separator: separator,
56+
numberOfBits: numberOfBits,
57+
}, nil
58+
}
59+
60+
// FunctionName implements sql.FunctionExpression
61+
func (e *ExportSet) FunctionName() string {
62+
return "export_set"
63+
}
64+
65+
// Description implements sql.FunctionExpression
66+
func (e *ExportSet) Description() string {
67+
return "returns a string such that for every bit set in the value bits, you get an on string and for every unset bit, you get an off string."
68+
}
69+
70+
// Children implements the Expression interface
71+
func (e *ExportSet) Children() []sql.Expression {
72+
children := []sql.Expression{e.bits, e.on, e.off}
73+
if e.separator != nil {
74+
children = append(children, e.separator)
75+
}
76+
if e.numberOfBits != nil {
77+
children = append(children, e.numberOfBits)
78+
}
79+
return children
80+
}
81+
82+
// Resolved implements the Expression interface
83+
func (e *ExportSet) Resolved() bool {
84+
for _, child := range e.Children() {
85+
if !child.Resolved() {
86+
return false
87+
}
88+
}
89+
return true
90+
}
91+
92+
// IsNullable implements the Expression interface
93+
func (e *ExportSet) IsNullable() bool {
94+
for _, child := range e.Children() {
95+
if child.IsNullable() {
96+
return true
97+
}
98+
}
99+
return false
100+
}
101+
102+
// Type implements the Expression interface
103+
func (e *ExportSet) Type() sql.Type {
104+
return types.LongText
105+
}
106+
107+
// CollationCoercibility implements the interface sql.CollationCoercible
108+
func (e *ExportSet) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
109+
collation, coercibility = sql.GetCoercibility(ctx, e.on)
110+
otherCollation, otherCoercibility := sql.GetCoercibility(ctx, e.off)
111+
collation, coercibility = sql.ResolveCoercibility(collation, coercibility, otherCollation, otherCoercibility)
112+
if e.separator != nil {
113+
otherCollation, otherCoercibility = sql.GetCoercibility(ctx, e.separator)
114+
collation, coercibility = sql.ResolveCoercibility(collation, coercibility, otherCollation, otherCoercibility)
115+
}
116+
return collation, coercibility
117+
}
118+
119+
// String implements the Expression interface
120+
func (e *ExportSet) String() string {
121+
children := e.Children()
122+
childStrs := make([]string, len(children))
123+
for i, child := range children {
124+
childStrs[i] = child.String()
125+
}
126+
return fmt.Sprintf("export_set(%s)", strings.Join(childStrs, ", "))
127+
}
128+
129+
// WithChildren implements the Expression interface
130+
func (e *ExportSet) WithChildren(children ...sql.Expression) (sql.Expression, error) {
131+
return NewExportSet(children...)
132+
}
133+
134+
// Eval implements the Expression interface
135+
func (e *ExportSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
136+
bitsVal, err := e.bits.Eval(ctx, row)
137+
if err != nil {
138+
return nil, err
139+
}
140+
if bitsVal == nil {
141+
return nil, nil
142+
}
143+
144+
onVal, err := e.on.Eval(ctx, row)
145+
if err != nil {
146+
return nil, err
147+
}
148+
if onVal == nil {
149+
return nil, nil
150+
}
151+
152+
offVal, err := e.off.Eval(ctx, row)
153+
if err != nil {
154+
return nil, err
155+
}
156+
if offVal == nil {
157+
return nil, nil
158+
}
159+
160+
// Default separator is comma
161+
separatorVal := ","
162+
if e.separator != nil {
163+
sepVal, err := e.separator.Eval(ctx, row)
164+
if err != nil {
165+
return nil, err
166+
}
167+
if sepVal == nil {
168+
return nil, nil
169+
}
170+
sepStr, _, err := types.LongText.Convert(ctx, sepVal)
171+
if err != nil {
172+
return nil, err
173+
}
174+
separatorVal = sepStr.(string)
175+
}
176+
177+
// Default number of bits is 64
178+
numberOfBitsVal := int64(64)
179+
if e.numberOfBits != nil {
180+
numBitsVal, err := e.numberOfBits.Eval(ctx, row)
181+
if err != nil {
182+
return nil, err
183+
}
184+
if numBitsVal == nil {
185+
return nil, nil
186+
}
187+
numBitsInt, _, err := types.Int64.Convert(ctx, numBitsVal)
188+
if err != nil {
189+
return nil, err
190+
}
191+
numberOfBitsVal = numBitsInt.(int64)
192+
// MySQL silently clips to 64 if larger, treats negative as 64
193+
if numberOfBitsVal > 64 || numberOfBitsVal < 0 {
194+
numberOfBitsVal = 64
195+
}
196+
}
197+
198+
// Convert arguments to proper types
199+
bitsInt, _, err := types.Uint64.Convert(ctx, bitsVal)
200+
if err != nil {
201+
return nil, err
202+
}
203+
204+
onStr, _, err := types.LongText.Convert(ctx, onVal)
205+
if err != nil {
206+
return nil, err
207+
}
208+
209+
offStr, _, err := types.LongText.Convert(ctx, offVal)
210+
if err != nil {
211+
return nil, err
212+
}
213+
214+
bits := bitsInt.(uint64)
215+
on := onStr.(string)
216+
off := offStr.(string)
217+
218+
// Build the result by examining bits from right to left (LSB to MSB)
219+
// but adding strings from left to right
220+
result := make([]string, numberOfBitsVal)
221+
for i := int64(0); i < numberOfBitsVal; i++ {
222+
if (bits & (1 << uint(i))) != 0 {
223+
result[i] = on
224+
} else {
225+
result[i] = off
226+
}
227+
}
228+
229+
return strings.Join(result, separatorVal), nil
230+
}

0 commit comments

Comments
 (0)