1+ // Copyright 2020-2024 Dolthub, Inc.
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // http://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
15+ package function
16+
17+ import (
18+ "fmt"
19+ "strings"
20+
21+ "github.com/dolthub/go-mysql-server/sql"
22+ "github.com/dolthub/go-mysql-server/sql/types"
23+ )
24+
25+ // ExportSet implements the SQL function EXPORT_SET() which returns a string representation of bits in a number
26+ type ExportSet struct {
27+ bits sql.Expression
28+ on sql.Expression
29+ off sql.Expression
30+ separator sql.Expression
31+ numberOfBits sql.Expression
32+ }
33+
34+ var _ sql.FunctionExpression = (* ExportSet )(nil )
35+ var _ sql.CollationCoercible = (* ExportSet )(nil )
36+
37+ // NewExportSet creates a new ExportSet expression
38+ func NewExportSet (args ... sql.Expression ) (sql.Expression , error ) {
39+ if len (args ) < 3 || len (args ) > 5 {
40+ return nil , sql .ErrInvalidArgumentNumber .New ("EXPORT_SET" , "3, 4, or 5" , len (args ))
41+ }
42+
43+ var separator , numberOfBits sql.Expression
44+ if len (args ) >= 4 {
45+ separator = args [3 ]
46+ }
47+ if len (args ) == 5 {
48+ numberOfBits = args [4 ]
49+ }
50+
51+ return & ExportSet {
52+ bits : args [0 ],
53+ on : args [1 ],
54+ off : args [2 ],
55+ separator : separator ,
56+ numberOfBits : numberOfBits ,
57+ }, nil
58+ }
59+
60+ // FunctionName implements sql.FunctionExpression
61+ func (e * ExportSet ) FunctionName () string {
62+ return "export_set"
63+ }
64+
65+ // Description implements sql.FunctionExpression
66+ func (e * ExportSet ) Description () string {
67+ return "returns a string such that for every bit set in the value bits, you get an on string and for every unset bit, you get an off string."
68+ }
69+
70+ // Children implements the Expression interface
71+ func (e * ExportSet ) Children () []sql.Expression {
72+ children := []sql.Expression {e .bits , e .on , e .off }
73+ if e .separator != nil {
74+ children = append (children , e .separator )
75+ }
76+ if e .numberOfBits != nil {
77+ children = append (children , e .numberOfBits )
78+ }
79+ return children
80+ }
81+
82+ // Resolved implements the Expression interface
83+ func (e * ExportSet ) Resolved () bool {
84+ for _ , child := range e .Children () {
85+ if ! child .Resolved () {
86+ return false
87+ }
88+ }
89+ return true
90+ }
91+
92+ // IsNullable implements the Expression interface
93+ func (e * ExportSet ) IsNullable () bool {
94+ for _ , child := range e .Children () {
95+ if child .IsNullable () {
96+ return true
97+ }
98+ }
99+ return false
100+ }
101+
102+ // Type implements the Expression interface
103+ func (e * ExportSet ) Type () sql.Type {
104+ return types .LongText
105+ }
106+
107+ // CollationCoercibility implements the interface sql.CollationCoercible
108+ func (e * ExportSet ) CollationCoercibility (ctx * sql.Context ) (collation sql.CollationID , coercibility byte ) {
109+ collation , coercibility = sql .GetCoercibility (ctx , e .on )
110+ otherCollation , otherCoercibility := sql .GetCoercibility (ctx , e .off )
111+ collation , coercibility = sql .ResolveCoercibility (collation , coercibility , otherCollation , otherCoercibility )
112+ if e .separator != nil {
113+ otherCollation , otherCoercibility = sql .GetCoercibility (ctx , e .separator )
114+ collation , coercibility = sql .ResolveCoercibility (collation , coercibility , otherCollation , otherCoercibility )
115+ }
116+ return collation , coercibility
117+ }
118+
119+ // String implements the Expression interface
120+ func (e * ExportSet ) String () string {
121+ children := e .Children ()
122+ childStrs := make ([]string , len (children ))
123+ for i , child := range children {
124+ childStrs [i ] = child .String ()
125+ }
126+ return fmt .Sprintf ("export_set(%s)" , strings .Join (childStrs , ", " ))
127+ }
128+
129+ // WithChildren implements the Expression interface
130+ func (e * ExportSet ) WithChildren (children ... sql.Expression ) (sql.Expression , error ) {
131+ return NewExportSet (children ... )
132+ }
133+
134+ // Eval implements the Expression interface
135+ func (e * ExportSet ) Eval (ctx * sql.Context , row sql.Row ) (interface {}, error ) {
136+ bitsVal , err := e .bits .Eval (ctx , row )
137+ if err != nil {
138+ return nil , err
139+ }
140+ if bitsVal == nil {
141+ return nil , nil
142+ }
143+
144+ onVal , err := e .on .Eval (ctx , row )
145+ if err != nil {
146+ return nil , err
147+ }
148+ if onVal == nil {
149+ return nil , nil
150+ }
151+
152+ offVal , err := e .off .Eval (ctx , row )
153+ if err != nil {
154+ return nil , err
155+ }
156+ if offVal == nil {
157+ return nil , nil
158+ }
159+
160+ // Default separator is comma
161+ separatorVal := ","
162+ if e .separator != nil {
163+ sepVal , err := e .separator .Eval (ctx , row )
164+ if err != nil {
165+ return nil , err
166+ }
167+ if sepVal == nil {
168+ return nil , nil
169+ }
170+ sepStr , _ , err := types .LongText .Convert (ctx , sepVal )
171+ if err != nil {
172+ return nil , err
173+ }
174+ separatorVal = sepStr .(string )
175+ }
176+
177+ // Default number of bits is 64
178+ numberOfBitsVal := int64 (64 )
179+ if e .numberOfBits != nil {
180+ numBitsVal , err := e .numberOfBits .Eval (ctx , row )
181+ if err != nil {
182+ return nil , err
183+ }
184+ if numBitsVal == nil {
185+ return nil , nil
186+ }
187+ numBitsInt , _ , err := types .Int64 .Convert (ctx , numBitsVal )
188+ if err != nil {
189+ return nil , err
190+ }
191+ numberOfBitsVal = numBitsInt .(int64 )
192+ // MySQL silently clips to 64 if larger, treats negative as 64
193+ if numberOfBitsVal > 64 || numberOfBitsVal < 0 {
194+ numberOfBitsVal = 64
195+ }
196+ }
197+
198+ // Convert arguments to proper types
199+ bitsInt , _ , err := types .Uint64 .Convert (ctx , bitsVal )
200+ if err != nil {
201+ return nil , err
202+ }
203+
204+ onStr , _ , err := types .LongText .Convert (ctx , onVal )
205+ if err != nil {
206+ return nil , err
207+ }
208+
209+ offStr , _ , err := types .LongText .Convert (ctx , offVal )
210+ if err != nil {
211+ return nil , err
212+ }
213+
214+ bits := bitsInt .(uint64 )
215+ on := onStr .(string )
216+ off := offStr .(string )
217+
218+ // Build the result by examining bits from right to left (LSB to MSB)
219+ // but adding strings from left to right
220+ result := make ([]string , numberOfBitsVal )
221+ for i := int64 (0 ); i < numberOfBitsVal ; i ++ {
222+ if (bits & (1 << uint (i ))) != 0 {
223+ result [i ] = on
224+ } else {
225+ result [i ] = off
226+ }
227+ }
228+
229+ return strings .Join (result , separatorVal ), nil
230+ }
0 commit comments