1
+ // Copyright 2020-2024 Dolthub, Inc.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ package function
16
+
17
+ import (
18
+ "fmt"
19
+ "strings"
20
+
21
+ "github.com/dolthub/go-mysql-server/sql"
22
+ "github.com/dolthub/go-mysql-server/sql/types"
23
+ )
24
+
25
+ // ExportSet implements the SQL function EXPORT_SET() which returns a string representation of bits in a number
26
+ type ExportSet struct {
27
+ bits sql.Expression
28
+ on sql.Expression
29
+ off sql.Expression
30
+ separator sql.Expression
31
+ numberOfBits sql.Expression
32
+ }
33
+
34
+ var _ sql.FunctionExpression = (* ExportSet )(nil )
35
+ var _ sql.CollationCoercible = (* ExportSet )(nil )
36
+
37
+ // NewExportSet creates a new ExportSet expression
38
+ func NewExportSet (args ... sql.Expression ) (sql.Expression , error ) {
39
+ if len (args ) < 3 || len (args ) > 5 {
40
+ return nil , sql .ErrInvalidArgumentNumber .New ("EXPORT_SET" , "3, 4, or 5" , len (args ))
41
+ }
42
+
43
+ var separator , numberOfBits sql.Expression
44
+ if len (args ) >= 4 {
45
+ separator = args [3 ]
46
+ }
47
+ if len (args ) == 5 {
48
+ numberOfBits = args [4 ]
49
+ }
50
+
51
+ return & ExportSet {
52
+ bits : args [0 ],
53
+ on : args [1 ],
54
+ off : args [2 ],
55
+ separator : separator ,
56
+ numberOfBits : numberOfBits ,
57
+ }, nil
58
+ }
59
+
60
+ // FunctionName implements sql.FunctionExpression
61
+ func (e * ExportSet ) FunctionName () string {
62
+ return "export_set"
63
+ }
64
+
65
+ // Description implements sql.FunctionExpression
66
+ func (e * ExportSet ) Description () string {
67
+ return "returns a string such that for every bit set in the value bits, you get an on string and for every unset bit, you get an off string."
68
+ }
69
+
70
+ // Children implements the Expression interface
71
+ func (e * ExportSet ) Children () []sql.Expression {
72
+ children := []sql.Expression {e .bits , e .on , e .off }
73
+ if e .separator != nil {
74
+ children = append (children , e .separator )
75
+ }
76
+ if e .numberOfBits != nil {
77
+ children = append (children , e .numberOfBits )
78
+ }
79
+ return children
80
+ }
81
+
82
+ // Resolved implements the Expression interface
83
+ func (e * ExportSet ) Resolved () bool {
84
+ for _ , child := range e .Children () {
85
+ if ! child .Resolved () {
86
+ return false
87
+ }
88
+ }
89
+ return true
90
+ }
91
+
92
+ // IsNullable implements the Expression interface
93
+ func (e * ExportSet ) IsNullable () bool {
94
+ for _ , child := range e .Children () {
95
+ if child .IsNullable () {
96
+ return true
97
+ }
98
+ }
99
+ return false
100
+ }
101
+
102
+ // Type implements the Expression interface
103
+ func (e * ExportSet ) Type () sql.Type {
104
+ return types .LongText
105
+ }
106
+
107
+ // CollationCoercibility implements the interface sql.CollationCoercible
108
+ func (e * ExportSet ) CollationCoercibility (ctx * sql.Context ) (collation sql.CollationID , coercibility byte ) {
109
+ collation , coercibility = sql .GetCoercibility (ctx , e .on )
110
+ otherCollation , otherCoercibility := sql .GetCoercibility (ctx , e .off )
111
+ collation , coercibility = sql .ResolveCoercibility (collation , coercibility , otherCollation , otherCoercibility )
112
+ if e .separator != nil {
113
+ otherCollation , otherCoercibility = sql .GetCoercibility (ctx , e .separator )
114
+ collation , coercibility = sql .ResolveCoercibility (collation , coercibility , otherCollation , otherCoercibility )
115
+ }
116
+ return collation , coercibility
117
+ }
118
+
119
+ // String implements the Expression interface
120
+ func (e * ExportSet ) String () string {
121
+ children := e .Children ()
122
+ childStrs := make ([]string , len (children ))
123
+ for i , child := range children {
124
+ childStrs [i ] = child .String ()
125
+ }
126
+ return fmt .Sprintf ("export_set(%s)" , strings .Join (childStrs , ", " ))
127
+ }
128
+
129
+ // WithChildren implements the Expression interface
130
+ func (e * ExportSet ) WithChildren (children ... sql.Expression ) (sql.Expression , error ) {
131
+ return NewExportSet (children ... )
132
+ }
133
+
134
+ // Eval implements the Expression interface
135
+ func (e * ExportSet ) Eval (ctx * sql.Context , row sql.Row ) (interface {}, error ) {
136
+ bitsVal , err := e .bits .Eval (ctx , row )
137
+ if err != nil {
138
+ return nil , err
139
+ }
140
+ if bitsVal == nil {
141
+ return nil , nil
142
+ }
143
+
144
+ onVal , err := e .on .Eval (ctx , row )
145
+ if err != nil {
146
+ return nil , err
147
+ }
148
+ if onVal == nil {
149
+ return nil , nil
150
+ }
151
+
152
+ offVal , err := e .off .Eval (ctx , row )
153
+ if err != nil {
154
+ return nil , err
155
+ }
156
+ if offVal == nil {
157
+ return nil , nil
158
+ }
159
+
160
+ // Default separator is comma
161
+ separatorVal := ","
162
+ if e .separator != nil {
163
+ sepVal , err := e .separator .Eval (ctx , row )
164
+ if err != nil {
165
+ return nil , err
166
+ }
167
+ if sepVal == nil {
168
+ return nil , nil
169
+ }
170
+ sepStr , _ , err := types .LongText .Convert (ctx , sepVal )
171
+ if err != nil {
172
+ return nil , err
173
+ }
174
+ separatorVal = sepStr .(string )
175
+ }
176
+
177
+ // Default number of bits is 64
178
+ numberOfBitsVal := int64 (64 )
179
+ if e .numberOfBits != nil {
180
+ numBitsVal , err := e .numberOfBits .Eval (ctx , row )
181
+ if err != nil {
182
+ return nil , err
183
+ }
184
+ if numBitsVal == nil {
185
+ return nil , nil
186
+ }
187
+ numBitsInt , _ , err := types .Int64 .Convert (ctx , numBitsVal )
188
+ if err != nil {
189
+ return nil , err
190
+ }
191
+ numberOfBitsVal = numBitsInt .(int64 )
192
+ // MySQL silently clips to 64 if larger, treats negative as 64
193
+ if numberOfBitsVal > 64 || numberOfBitsVal < 0 {
194
+ numberOfBitsVal = 64
195
+ }
196
+ }
197
+
198
+ // Convert arguments to proper types
199
+ bitsInt , _ , err := types .Uint64 .Convert (ctx , bitsVal )
200
+ if err != nil {
201
+ return nil , err
202
+ }
203
+
204
+ onStr , _ , err := types .LongText .Convert (ctx , onVal )
205
+ if err != nil {
206
+ return nil , err
207
+ }
208
+
209
+ offStr , _ , err := types .LongText .Convert (ctx , offVal )
210
+ if err != nil {
211
+ return nil , err
212
+ }
213
+
214
+ bits := bitsInt .(uint64 )
215
+ on := onStr .(string )
216
+ off := offStr .(string )
217
+
218
+ // Build the result by examining bits from right to left (LSB to MSB)
219
+ // but adding strings from left to right
220
+ result := make ([]string , numberOfBitsVal )
221
+ for i := int64 (0 ); i < numberOfBitsVal ; i ++ {
222
+ if (bits & (1 << uint (i ))) != 0 {
223
+ result [i ] = on
224
+ } else {
225
+ result [i ] = off
226
+ }
227
+ }
228
+
229
+ return strings .Join (result , separatorVal ), nil
230
+ }
0 commit comments