Skip to content

Commit 7c4ada8

Browse files
authored
Fix formatting of *float64 parameters (#215)
Attempt to fix #214 Signed-off-by: Esdras Beleza <[email protected]>
1 parent e82880f commit 7c4ada8

File tree

4 files changed

+78
-49
lines changed

4 files changed

+78
-49
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Release History
22

3+
- Fix formatting of *float64 parameters
4+
35
## v1.5.4 (2024-04-10)
46

57
- Added OAuth support for GCP (databricks/databricks-sql-go#189 by @rcypher-databricks)

driver_e2e_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,10 @@ func strPtr(s string) *string {
504504
return &s
505505
}
506506

507+
func float64Ptr(f float64) *float64 {
508+
return &f
509+
}
510+
507511
func loadTestData(t *testing.T, name string, v any) {
508512
if f, err := os.ReadFile(fmt.Sprintf("testdata/%s", name)); err != nil {
509513
t.Errorf("could not read data from: %s", name)

parameter_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,14 @@ import (
1212

1313
func TestParameter_Inference(t *testing.T) {
1414
t.Run("Should infer types correctly", func(t *testing.T) {
15-
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: Parameter{Value: "6.2", Type: SqlDecimal}}}
15+
values := [6]driver.NamedValue{
16+
{Name: "", Value: float32(5.1)},
17+
{Name: "", Value: time.Now()},
18+
{Name: "", Value: int64(5)},
19+
{Name: "", Value: true},
20+
{Name: "", Value: Parameter{Value: "6.2", Type: SqlDecimal}},
21+
{Name: "", Value: Parameter{Value: float64Ptr(6.2), Type: SqlUnkown}},
22+
}
1623
parameters := convertNamedValuesToSparkParams(values[:])
1724
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
1825
assert.NotNil(t, parameters[1].Value.StringValue)
@@ -21,6 +28,7 @@ func TestParameter_Inference(t *testing.T) {
2128
assert.Equal(t, string("true"), *parameters[3].Value.StringValue)
2229
assert.Equal(t, string("DECIMAL(2,1)"), *parameters[4].Type)
2330
assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue)
31+
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, parameters[5].Value)
2432
})
2533
}
2634
func TestParameters_Names(t *testing.T) {

parameters.go

Lines changed: 63 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package dbsql
33
import (
44
"database/sql/driver"
55
"fmt"
6+
"reflect"
67
"strconv"
78
"strings"
89
"time"
@@ -90,57 +91,71 @@ func inferTypes(params []Parameter) {
9091
for i := range params {
9192
param := &params[i]
9293
if param.Type == SqlUnkown {
93-
switch value := param.Value.(type) {
94-
case bool:
95-
param.Value = strconv.FormatBool(value)
96-
param.Type = SqlBoolean
97-
case string:
98-
param.Value = value
99-
param.Type = SqlString
100-
case int:
101-
param.Value = strconv.Itoa(value)
102-
param.Type = SqlInteger
103-
case uint:
104-
param.Value = strconv.FormatUint(uint64(value), 10)
105-
param.Type = SqlInteger
106-
case int8:
107-
param.Value = strconv.Itoa(int(value))
108-
param.Type = SqlInteger
109-
case uint8:
110-
param.Value = strconv.FormatUint(uint64(value), 10)
111-
param.Type = SqlInteger
112-
case int16:
113-
param.Value = strconv.Itoa(int(value))
114-
param.Type = SqlInteger
115-
case uint16:
116-
param.Value = strconv.FormatUint(uint64(value), 10)
117-
param.Type = SqlInteger
118-
case int32:
119-
param.Value = strconv.Itoa(int(value))
120-
param.Type = SqlInteger
121-
case uint32:
122-
param.Value = strconv.FormatUint(uint64(value), 10)
123-
param.Type = SqlInteger
124-
case int64:
125-
param.Value = strconv.Itoa(int(value))
126-
param.Type = SqlInteger
127-
case uint64:
128-
param.Value = strconv.FormatUint(uint64(value), 10)
129-
param.Type = SqlInteger
130-
case float32:
131-
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
132-
param.Type = SqlFloat
133-
case time.Time:
134-
param.Value = value.Format(time.RFC3339Nano)
135-
param.Type = SqlTimestamp
136-
default:
137-
s := fmt.Sprintf("%s", param.Value)
138-
param.Value = s
139-
param.Type = SqlString
140-
}
94+
inferType(param)
14195
}
14296
}
14397
}
98+
99+
func inferType(param *Parameter) {
100+
if param.Value != nil && reflect.ValueOf(param.Value).Kind() == reflect.Ptr {
101+
param.Value = reflect.ValueOf(param.Value).Elem().Interface()
102+
inferType(param)
103+
return
104+
}
105+
106+
switch value := param.Value.(type) {
107+
case bool:
108+
param.Value = strconv.FormatBool(value)
109+
param.Type = SqlBoolean
110+
case string:
111+
param.Value = value
112+
param.Type = SqlString
113+
case int:
114+
param.Value = strconv.Itoa(value)
115+
param.Type = SqlInteger
116+
case uint:
117+
param.Value = strconv.FormatUint(uint64(value), 10)
118+
param.Type = SqlInteger
119+
case int8:
120+
param.Value = strconv.Itoa(int(value))
121+
param.Type = SqlInteger
122+
case uint8:
123+
param.Value = strconv.FormatUint(uint64(value), 10)
124+
param.Type = SqlInteger
125+
case int16:
126+
param.Value = strconv.Itoa(int(value))
127+
param.Type = SqlInteger
128+
case uint16:
129+
param.Value = strconv.FormatUint(uint64(value), 10)
130+
param.Type = SqlInteger
131+
case int32:
132+
param.Value = strconv.Itoa(int(value))
133+
param.Type = SqlInteger
134+
case uint32:
135+
param.Value = strconv.FormatUint(uint64(value), 10)
136+
param.Type = SqlInteger
137+
case int64:
138+
param.Value = strconv.Itoa(int(value))
139+
param.Type = SqlInteger
140+
case uint64:
141+
param.Value = strconv.FormatUint(uint64(value), 10)
142+
param.Type = SqlInteger
143+
case float32:
144+
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
145+
param.Type = SqlFloat
146+
case float64:
147+
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 64)
148+
param.Type = SqlFloat
149+
case time.Time:
150+
param.Value = value.Format(time.RFC3339Nano)
151+
param.Type = SqlTimestamp
152+
default:
153+
s := fmt.Sprintf("%s", param.Value)
154+
param.Value = s
155+
param.Type = SqlString
156+
}
157+
}
158+
144159
func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
145160
var sparkParams []*cli_service.TSparkParameter
146161

0 commit comments

Comments
 (0)