@@ -10,141 +10,147 @@ import (
1010 "github.com/databricks/databricks-sql-go/internal/cli_service"
1111)
1212
13- type DBSqlParam struct {
13+ type Parameter struct {
1414 Name string
1515 Type SqlType
1616 Value any
1717}
1818
19- type SqlType int64
19+ type SqlType int
2020
2121const (
22- String SqlType = iota
23- Date
24- Timestamp
25- Float
26- Decimal
27- Double
28- Integer
29- BigInt
30- SmallInt
31- TinyInt
32- Boolean
33- IntervalMonth
34- IntervalDay
22+ SqlUnkown SqlType = iota
23+ SqlString
24+ SqlDate
25+ SqlTimestamp
26+ SqlFloat
27+ SqlDecimal
28+ SqlDouble
29+ SqlInteger
30+ SqlBigInt
31+ SqlSmallInt
32+ SqlTinyInt
33+ SqlBoolean
34+ SqlIntervalMonth
35+ SqlIntervalDay
3536)
3637
3738func (s SqlType ) String () string {
3839 switch s {
39- case String :
40+ case SqlString :
4041 return "STRING"
41- case Date :
42+ case SqlDate :
4243 return "DATE"
43- case Timestamp :
44+ case SqlTimestamp :
4445 return "TIMESTAMP"
45- case Float :
46+ case SqlFloat :
4647 return "FLOAT"
47- case Decimal :
48+ case SqlDecimal :
4849 return "DECIMAL"
49- case Double :
50+ case SqlDouble :
5051 return "DOUBLE"
51- case Integer :
52+ case SqlInteger :
5253 return "INTEGER"
53- case BigInt :
54+ case SqlBigInt :
5455 return "BIGINT"
55- case SmallInt :
56+ case SqlSmallInt :
5657 return "SMALLINT"
57- case TinyInt :
58+ case SqlTinyInt :
5859 return "TINYINT"
59- case Boolean :
60+ case SqlBoolean :
6061 return "BOOLEAN"
61- case IntervalMonth :
62+ case SqlIntervalMonth :
6263 return "INTERVAL MONTH"
63- case IntervalDay :
64+ case SqlIntervalDay :
6465 return "INTERVAL DAY"
6566 }
6667 return "unknown"
6768}
6869
69- func valuesToDBSQLParams (namedValues []driver.NamedValue ) []DBSqlParam {
70- var params []DBSqlParam
70+ func valuesToParameters (namedValues []driver.NamedValue ) []Parameter {
71+ var params []Parameter
7172 for i := range namedValues {
73+ newParam := * new (Parameter )
7274 namedValue := namedValues [i ]
73- param := * new (DBSqlParam )
74- param .Name = namedValue .Name
75- param .Value = namedValue .Value
76- params = append (params , param )
75+ param , ok := namedValue .Value .(Parameter )
76+ if ok {
77+ newParam .Name = param .Name
78+ newParam .Value = param .Value
79+ newParam .Type = param .Type
80+ } else {
81+ newParam .Name = namedValue .Name
82+ newParam .Value = namedValue .Value
83+ }
84+ params = append (params , newParam )
7785 }
7886 return params
7987}
8088
81- func inferTypes (params []DBSqlParam ) {
89+ func inferTypes (params []Parameter ) {
8290 for i := range params {
8391 param := & params [i ]
84- switch value := param .Value .(type ) {
85- case bool :
86- param .Value = strconv .FormatBool (value )
87- param .Type = Boolean
88- case string :
89- param .Value = value
90- param .Type = String
91- case int :
92- param .Value = strconv .Itoa (value )
93- param .Type = Integer
94- case uint :
95- param .Value = strconv .FormatUint (uint64 (value ), 10 )
96- param .Type = Integer
97- case int8 :
98- param .Value = strconv .Itoa (int (value ))
99- param .Type = Integer
100- case uint8 :
101- param .Value = strconv .FormatUint (uint64 (value ), 10 )
102- param .Type = Integer
103- case int16 :
104- param .Value = strconv .Itoa (int (value ))
105- param .Type = Integer
106- case uint16 :
107- param .Value = strconv .FormatUint (uint64 (value ), 10 )
108- param .Type = Integer
109- case int32 :
110- param .Value = strconv .Itoa (int (value ))
111- param .Type = Integer
112- case uint32 :
113- param .Value = strconv .FormatUint (uint64 (value ), 10 )
114- param .Type = Integer
115- case int64 :
116- param .Value = strconv .Itoa (int (value ))
117- param .Type = Integer
118- case uint64 :
119- param .Value = strconv .FormatUint (uint64 (value ), 10 )
120- param .Type = Integer
121- case float32 :
122- param .Value = strconv .FormatFloat (float64 (value ), 'f' , - 1 , 32 )
123- param .Type = Float
124- case time.Time :
125- param .Value = value .String ()
126- param .Type = Timestamp
127- case DBSqlParam :
128- param .Name = value .Name
129- param .Value = value .Value
130- param .Type = value .Type
131- default :
132- s := fmt .Sprintf ("%s" , value )
133- param .Value = s
134- param .Type = String
92+ if param .Type == SqlUnkown {
93+ switch value := param .Value .(type ) {
94+ case bool :
95+ param .Value = strconv .FormatBool (value )
96+ param .Type = SqlBoolean
97+ case string :
98+ param .Value = value
99+ param .Type = SqlString
100+ case int :
101+ param .Value = strconv .Itoa (value )
102+ param .Type = SqlInteger
103+ case uint :
104+ param .Value = strconv .FormatUint (uint64 (value ), 10 )
105+ param .Type = SqlInteger
106+ case int8 :
107+ param .Value = strconv .Itoa (int (value ))
108+ param .Type = SqlInteger
109+ case uint8 :
110+ param .Value = strconv .FormatUint (uint64 (value ), 10 )
111+ param .Type = SqlInteger
112+ case int16 :
113+ param .Value = strconv .Itoa (int (value ))
114+ param .Type = SqlInteger
115+ case uint16 :
116+ param .Value = strconv .FormatUint (uint64 (value ), 10 )
117+ param .Type = SqlInteger
118+ case int32 :
119+ param .Value = strconv .Itoa (int (value ))
120+ param .Type = SqlInteger
121+ case uint32 :
122+ param .Value = strconv .FormatUint (uint64 (value ), 10 )
123+ param .Type = SqlInteger
124+ case int64 :
125+ param .Value = strconv .Itoa (int (value ))
126+ param .Type = SqlInteger
127+ case uint64 :
128+ param .Value = strconv .FormatUint (uint64 (value ), 10 )
129+ param .Type = SqlInteger
130+ case float32 :
131+ param .Value = strconv .FormatFloat (float64 (value ), 'f' , - 1 , 32 )
132+ param .Type = SqlFloat
133+ case time.Time :
134+ param .Value = value .String ()
135+ param .Type = SqlTimestamp
136+ default :
137+ s := fmt .Sprintf ("%s" , param .Value )
138+ param .Value = s
139+ param .Type = SqlString
140+ }
135141 }
136142 }
137143}
138144func convertNamedValuesToSparkParams (values []driver.NamedValue ) []* cli_service.TSparkParameter {
139145 var sparkParams []* cli_service.TSparkParameter
140146
141- sqlParams := valuesToDBSQLParams (values )
147+ sqlParams := valuesToParameters (values )
142148 inferTypes (sqlParams )
143149 for i := range sqlParams {
144150 sqlParam := sqlParams [i ]
145151 sparkParamValue := sqlParam .Value .(string )
146152 var sparkParamType string
147- if sqlParam .Type == Decimal {
153+ if sqlParam .Type == SqlDecimal {
148154 sparkParamType = inferDecimalType (sparkParamValue )
149155 } else {
150156 sparkParamType = sqlParam .Type .String ()
0 commit comments