
Commit 250160b

[PECO-1112] Added decimal handling (#167)
We need to dynamically set the actual DECIMAL type of decimal parameters, and it should be the smallest type that could encompass the decimal string.
2 parents c08cf71 + c20e62b commit 250160b
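For intuition, the sketch below mirrors the inference rule this commit introduces (see the new inferDecimalType in the parameters.go diff further down): precision is the count of significant digits and scale is the count of digits after the decimal point. The helper name and sample values here are illustrative only, not part of the commit.

```go
package main

import (
	"fmt"
	"strings"
)

// smallestDecimalType is an illustrative stand-in for the commit's
// inferDecimalType: it returns the narrowest DECIMAL(precision,scale)
// that can hold the given decimal string.
func smallestDecimalType(d string) string {
	var overall, after int
	switch {
	case strings.HasPrefix(d, "0."):
		// Value below one: the leading "0." does not count toward precision.
		overall, after = len(d)-2, len(d)-2
	case !strings.Contains(d, "."):
		// No fractional part.
		overall, after = len(d), 0
	default:
		parts := strings.Split(d, ".")
		overall, after = len(parts[0])+len(parts[1]), len(parts[1])
	}
	return fmt.Sprintf("DECIMAL(%d,%d)", overall, after)
}

func main() {
	// Prints: 6.2 -> DECIMAL(2,1), 0.035 -> DECIMAL(3,3), 1234 -> DECIMAL(4,0)
	for _, v := range []string{"6.2", "0.035", "1234"} {
		fmt.Printf("%s -> %s\n", v, smallestDecimalType(v))
	}
}
```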

2 files changed (+28 / -3 lines)

parameter_test.go

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ func TestParameter_Inference(t *testing.T) {
 		assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
 		assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
 		assert.Equal(t, string("true"), *parameters[3].Value.StringValue)
-		assert.Equal(t, string("DECIMAL"), *parameters[4].Type)
+		assert.Equal(t, string("DECIMAL(2,1)"), *parameters[4].Type)
 		assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue)
 	})
 }
@@ -31,6 +31,6 @@ func TestParameters_Names(t *testing.T) {
 		assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
 		assert.Equal(t, string("2"), *parameters[1].Name)
 		assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
-		assert.Equal(t, string("DECIMAL"), *parameters[1].Type)
+		assert.Equal(t, string("DECIMAL(2,1)"), *parameters[1].Type)
 	})
 }
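If direct coverage of the new inference is wanted beyond the assertions above, a small table-driven test along these lines could be added; this is only a sketch and assumes it lives in the same test file, in the same package as inferDecimalType.

```go
func TestInferDecimalType(t *testing.T) {
	cases := map[string]string{
		"6.2":   "DECIMAL(2,1)", // digits on both sides of the point
		"0.035": "DECIMAL(3,3)", // leading "0." does not count toward precision
		"1234":  "DECIMAL(4,0)", // integer, no fractional part
	}
	for in, want := range cases {
		assert.Equal(t, want, inferDecimalType(in))
	}
}
```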

parameters.go

Lines changed: 26 additions & 1 deletion
@@ -4,6 +4,7 @@ import (
 	"database/sql/driver"
 	"fmt"
 	"strconv"
+	"strings"
 	"time"
 
 	"github.com/databricks/databricks-sql-go/internal/cli_service"
@@ -142,9 +143,33 @@ func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.
 	for i := range sqlParams {
 		sqlParam := sqlParams[i]
 		sparkParamValue := sqlParam.Value.(string)
-		sparkParamType := sqlParam.Type.String()
+		var sparkParamType string
+		if sqlParam.Type == Decimal {
+			sparkParamType = inferDecimalType(sparkParamValue)
+		} else {
+			sparkParamType = sqlParam.Type.String()
+		}
 		sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: &cli_service.TSparkParameterValue{StringValue: &sparkParamValue}}
 		sparkParams = append(sparkParams, &sparkParam)
 	}
 	return sparkParams
 }
+
+func inferDecimalType(d string) (t string) {
+	var overall int
+	var after int
+	if strings.HasPrefix(d, "0.") {
+		// Less than one
+		overall = len(d) - 2
+		after = len(d) - 2
+	} else if !strings.Contains(d, ".") {
+		// No fractional part
+		overall = len(d)
+		after = 0
+	} else {
+		components := strings.Split(d, ".")
+		overall, after = len(components[0])+len(components[1]), len(components[1])
+	}
+
+	return fmt.Sprintf("DECIMAL(%d,%d)", overall, after)
+}
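To make the effect on the wire format concrete, the fragment below hand-builds roughly what the loop in convertNamedValuesToSparkParams now produces for a decimal value of "6.2"; previously Type would have been the bare string "DECIMAL". This is a sketch that assumes it sits in the same package as parameters.go (so inferDecimalType and the internal cli_service import are in scope), and the parameter name is made up.

```go
// buildExampleDecimalParam is a hypothetical helper, for illustration only.
func buildExampleDecimalParam() cli_service.TSparkParameter {
	name := "p1"                   // illustrative parameter name
	typ := inferDecimalType("6.2") // "DECIMAL(2,1)" rather than just "DECIMAL"
	val := "6.2"
	return cli_service.TSparkParameter{
		Name:  &name,
		Type:  &typ,
		Value: &cli_service.TSparkParameterValue{StringValue: &val},
	}
}
```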
