Skip to content

Commit e95dd4a

Browse files
authored
[PECO-1048] Add example for parameterized queries (#168)
We don't currently have an example for parameterized queries, we should add one.
2 parents acdb8ba + 1ba2c5e commit e95dd4a

File tree

4 files changed

+167
-93
lines changed

4 files changed

+167
-93
lines changed

connection.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,9 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
596596

597597
func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
598598
var err error
599-
if dbsqlParam, ok := nv.Value.(DBSqlParam); ok {
600-
nv.Name = dbsqlParam.Name
601-
dbsqlParam.Value, err = driver.DefaultParameterConverter.ConvertValue(dbsqlParam.Value)
599+
if parameter, ok := nv.Value.(Parameter); ok {
600+
nv.Name = parameter.Name
601+
parameter.Value, err = driver.DefaultParameterConverter.ConvertValue(parameter.Value)
602602
return err
603603
}
604604

examples/parameters/main.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"fmt"
7+
"log"
8+
"os"
9+
"strconv"
10+
11+
dbsql "github.com/databricks/databricks-sql-go"
12+
"github.com/joho/godotenv"
13+
)
14+
15+
func main() {
16+
// Opening a driver typically will not attempt to connect to the database.
17+
err := godotenv.Load()
18+
19+
if err != nil {
20+
log.Fatal(err.Error())
21+
}
22+
port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
23+
if err != nil {
24+
log.Fatal(err.Error())
25+
}
26+
connector, err := dbsql.NewConnector(
27+
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
28+
dbsql.WithPort(port),
29+
dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
30+
dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
31+
)
32+
if err != nil {
33+
// This will not be a connection error, but a DSN parse error or
34+
// another initialization error.
35+
log.Fatal(err)
36+
}
37+
db := sql.OpenDB(connector)
38+
defer db.Close()
39+
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
40+
// defer cancel()
41+
ctx := context.Background()
42+
var p_bool bool
43+
var p_int int
44+
var p_double float64
45+
var p_float float32
46+
var p_date string
47+
err1 := db.QueryRowContext(ctx, `SELECT
48+
:p_bool AS col_bool,
49+
:p_int AS col_int,
50+
:p_double AS col_double,
51+
:p_float AS col_float,
52+
:p_date AS col_date`,
53+
dbsql.Parameter{Name: "p_bool", Value: true},
54+
dbsql.Parameter{Name: "p_int", Value: int(1234)},
55+
dbsql.Parameter{Name: "p_double", Type: dbsql.SqlDouble, Value: "3.14"},
56+
dbsql.Parameter{Name: "p_float", Type: dbsql.SqlFloat, Value: "3.14"},
57+
dbsql.Parameter{Name: "p_date", Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"}).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)
58+
59+
if err1 != nil {
60+
if err1 == sql.ErrNoRows {
61+
fmt.Println("not found")
62+
return
63+
} else {
64+
fmt.Printf("err: %v\n", err1)
65+
}
66+
}
67+
68+
}

parameter_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212

1313
func TestParameter_Inference(t *testing.T) {
1414
t.Run("Should infer types correctly", func(t *testing.T) {
15-
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: DBSqlParam{Value: "6.2", Type: Decimal}}}
15+
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: Parameter{Value: "6.2", Type: SqlDecimal}}}
1616
parameters := convertNamedValuesToSparkParams(values[:])
1717
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
1818
assert.NotNil(t, parameters[1].Value.StringValue)
@@ -25,7 +25,7 @@ func TestParameter_Inference(t *testing.T) {
2525
}
2626
func TestParameters_Names(t *testing.T) {
2727
t.Run("Should infer types correctly", func(t *testing.T) {
28-
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: DBSqlParam{Name: "2", Type: Decimal, Value: "6.2"}}}
28+
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: Parameter{Name: "2", Type: SqlDecimal, Value: "6.2"}}}
2929
parameters := convertNamedValuesToSparkParams(values[:])
3030
assert.Equal(t, string("1"), *parameters[0].Name)
3131
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)

parameters.go

Lines changed: 94 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -10,141 +10,147 @@ import (
1010
"github.com/databricks/databricks-sql-go/internal/cli_service"
1111
)
1212

13-
type DBSqlParam struct {
13+
type Parameter struct {
1414
Name string
1515
Type SqlType
1616
Value any
1717
}
1818

19-
type SqlType int64
19+
type SqlType int
2020

2121
const (
22-
String SqlType = iota
23-
Date
24-
Timestamp
25-
Float
26-
Decimal
27-
Double
28-
Integer
29-
BigInt
30-
SmallInt
31-
TinyInt
32-
Boolean
33-
IntervalMonth
34-
IntervalDay
22+
SqlUnkown SqlType = iota
23+
SqlString
24+
SqlDate
25+
SqlTimestamp
26+
SqlFloat
27+
SqlDecimal
28+
SqlDouble
29+
SqlInteger
30+
SqlBigInt
31+
SqlSmallInt
32+
SqlTinyInt
33+
SqlBoolean
34+
SqlIntervalMonth
35+
SqlIntervalDay
3536
)
3637

3738
func (s SqlType) String() string {
3839
switch s {
39-
case String:
40+
case SqlString:
4041
return "STRING"
41-
case Date:
42+
case SqlDate:
4243
return "DATE"
43-
case Timestamp:
44+
case SqlTimestamp:
4445
return "TIMESTAMP"
45-
case Float:
46+
case SqlFloat:
4647
return "FLOAT"
47-
case Decimal:
48+
case SqlDecimal:
4849
return "DECIMAL"
49-
case Double:
50+
case SqlDouble:
5051
return "DOUBLE"
51-
case Integer:
52+
case SqlInteger:
5253
return "INTEGER"
53-
case BigInt:
54+
case SqlBigInt:
5455
return "BIGINT"
55-
case SmallInt:
56+
case SqlSmallInt:
5657
return "SMALLINT"
57-
case TinyInt:
58+
case SqlTinyInt:
5859
return "TINYINT"
59-
case Boolean:
60+
case SqlBoolean:
6061
return "BOOLEAN"
61-
case IntervalMonth:
62+
case SqlIntervalMonth:
6263
return "INTERVAL MONTH"
63-
case IntervalDay:
64+
case SqlIntervalDay:
6465
return "INTERVAL DAY"
6566
}
6667
return "unknown"
6768
}
6869

69-
func valuesToDBSQLParams(namedValues []driver.NamedValue) []DBSqlParam {
70-
var params []DBSqlParam
70+
func valuesToParameters(namedValues []driver.NamedValue) []Parameter {
71+
var params []Parameter
7172
for i := range namedValues {
73+
newParam := *new(Parameter)
7274
namedValue := namedValues[i]
73-
param := *new(DBSqlParam)
74-
param.Name = namedValue.Name
75-
param.Value = namedValue.Value
76-
params = append(params, param)
75+
param, ok := namedValue.Value.(Parameter)
76+
if ok {
77+
newParam.Name = param.Name
78+
newParam.Value = param.Value
79+
newParam.Type = param.Type
80+
} else {
81+
newParam.Name = namedValue.Name
82+
newParam.Value = namedValue.Value
83+
}
84+
params = append(params, newParam)
7785
}
7886
return params
7987
}
8088

81-
func inferTypes(params []DBSqlParam) {
89+
func inferTypes(params []Parameter) {
8290
for i := range params {
8391
param := &params[i]
84-
switch value := param.Value.(type) {
85-
case bool:
86-
param.Value = strconv.FormatBool(value)
87-
param.Type = Boolean
88-
case string:
89-
param.Value = value
90-
param.Type = String
91-
case int:
92-
param.Value = strconv.Itoa(value)
93-
param.Type = Integer
94-
case uint:
95-
param.Value = strconv.FormatUint(uint64(value), 10)
96-
param.Type = Integer
97-
case int8:
98-
param.Value = strconv.Itoa(int(value))
99-
param.Type = Integer
100-
case uint8:
101-
param.Value = strconv.FormatUint(uint64(value), 10)
102-
param.Type = Integer
103-
case int16:
104-
param.Value = strconv.Itoa(int(value))
105-
param.Type = Integer
106-
case uint16:
107-
param.Value = strconv.FormatUint(uint64(value), 10)
108-
param.Type = Integer
109-
case int32:
110-
param.Value = strconv.Itoa(int(value))
111-
param.Type = Integer
112-
case uint32:
113-
param.Value = strconv.FormatUint(uint64(value), 10)
114-
param.Type = Integer
115-
case int64:
116-
param.Value = strconv.Itoa(int(value))
117-
param.Type = Integer
118-
case uint64:
119-
param.Value = strconv.FormatUint(uint64(value), 10)
120-
param.Type = Integer
121-
case float32:
122-
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
123-
param.Type = Float
124-
case time.Time:
125-
param.Value = value.String()
126-
param.Type = Timestamp
127-
case DBSqlParam:
128-
param.Name = value.Name
129-
param.Value = value.Value
130-
param.Type = value.Type
131-
default:
132-
s := fmt.Sprintf("%s", value)
133-
param.Value = s
134-
param.Type = String
92+
if param.Type == SqlUnkown {
93+
switch value := param.Value.(type) {
94+
case bool:
95+
param.Value = strconv.FormatBool(value)
96+
param.Type = SqlBoolean
97+
case string:
98+
param.Value = value
99+
param.Type = SqlString
100+
case int:
101+
param.Value = strconv.Itoa(value)
102+
param.Type = SqlInteger
103+
case uint:
104+
param.Value = strconv.FormatUint(uint64(value), 10)
105+
param.Type = SqlInteger
106+
case int8:
107+
param.Value = strconv.Itoa(int(value))
108+
param.Type = SqlInteger
109+
case uint8:
110+
param.Value = strconv.FormatUint(uint64(value), 10)
111+
param.Type = SqlInteger
112+
case int16:
113+
param.Value = strconv.Itoa(int(value))
114+
param.Type = SqlInteger
115+
case uint16:
116+
param.Value = strconv.FormatUint(uint64(value), 10)
117+
param.Type = SqlInteger
118+
case int32:
119+
param.Value = strconv.Itoa(int(value))
120+
param.Type = SqlInteger
121+
case uint32:
122+
param.Value = strconv.FormatUint(uint64(value), 10)
123+
param.Type = SqlInteger
124+
case int64:
125+
param.Value = strconv.Itoa(int(value))
126+
param.Type = SqlInteger
127+
case uint64:
128+
param.Value = strconv.FormatUint(uint64(value), 10)
129+
param.Type = SqlInteger
130+
case float32:
131+
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
132+
param.Type = SqlFloat
133+
case time.Time:
134+
param.Value = value.String()
135+
param.Type = SqlTimestamp
136+
default:
137+
s := fmt.Sprintf("%s", param.Value)
138+
param.Value = s
139+
param.Type = SqlString
140+
}
135141
}
136142
}
137143
}
138144
func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
139145
var sparkParams []*cli_service.TSparkParameter
140146

141-
sqlParams := valuesToDBSQLParams(values)
147+
sqlParams := valuesToParameters(values)
142148
inferTypes(sqlParams)
143149
for i := range sqlParams {
144150
sqlParam := sqlParams[i]
145151
sparkParamValue := sqlParam.Value.(string)
146152
var sparkParamType string
147-
if sqlParam.Type == Decimal {
153+
if sqlParam.Type == SqlDecimal {
148154
sparkParamType = inferDecimalType(sparkParamValue)
149155
} else {
150156
sparkParamType = sqlParam.Type.String()

0 commit comments

Comments
 (0)