Skip to content

Commit 5a3a210

Browse files
authored
[PECO-1015] Add support for staging operations to Go Driver (#164)
We need to add support for staging operations to the go driver. This PR enabled Get, Delete, and Put operations.
2 parents 86525e6 + 625f6b4 commit 5a3a210

File tree

9 files changed

+335
-2
lines changed

9 files changed

+335
-2
lines changed

connection.go

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
package dbsql
22

33
import (
4+
"bytes"
45
"context"
56
"database/sql/driver"
7+
"encoding/json"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"os"
12+
"path/filepath"
13+
"strings"
614
"time"
715

816
"github.com/databricks/databricks-sql-go/driverctx"
@@ -94,6 +102,7 @@ func (c *conn) IsValid() bool {
94102
// ExecContext honors the context timeout and return when it is canceled.
95103
// Statement ExecContext is the same as connection ExecContext
96104
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
105+
97106
corrId := driverctx.CorrelationIdFromContext(ctx)
98107
log := logger.WithContext(c.id, corrId, "")
99108
msg, start := logger.Track("ExecContext")
@@ -108,6 +117,34 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
108117
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
109118

110119
if exStmtResp != nil && exStmtResp.OperationHandle != nil {
120+
var isStagingOperation bool
121+
if exStmtResp.DirectResults != nil && exStmtResp.DirectResults.ResultSetMetadata != nil && exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation != nil {
122+
isStagingOperation = *exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation
123+
} else {
124+
req := cli_service.TGetResultSetMetadataReq{
125+
OperationHandle: exStmtResp.OperationHandle,
126+
}
127+
resp, err := c.client.GetResultSetMetadata(ctx, &req)
128+
if err != nil {
129+
return nil, dbsqlerrint.NewDriverError(ctx, "error performing staging operation", err)
130+
}
131+
isStagingOperation = *resp.IsStagingOperation
132+
}
133+
if isStagingOperation {
134+
if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
135+
row, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
136+
if err != nil {
137+
return nil, dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
138+
}
139+
err = c.ExecStagingOperation(ctx, row)
140+
if err != nil {
141+
return nil, err
142+
}
143+
} else {
144+
return nil, dbsqlerrint.NewDriverError(ctx, "staging ctx must be provided.", nil)
145+
}
146+
}
147+
111148
// we have an operation id so update the logger
112149
log = logger.WithContext(c.id, corrId, client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID))
113150

@@ -133,6 +170,163 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
133170
return &res, nil
134171
}
135172

173+
func Succeeded(response *http.Response) bool {
174+
if response.StatusCode == 200 || response.StatusCode == 201 || response.StatusCode == 202 || response.StatusCode == 204 {
175+
return true
176+
}
177+
return false
178+
}
179+
180+
func (c *conn) HandleStagingPut(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError {
181+
if localFile == "" {
182+
return dbsqlerrint.NewDriverError(ctx, "cannot perform PUT without specifying a local_file", nil)
183+
}
184+
client := &http.Client{}
185+
186+
dat, err := os.ReadFile(localFile)
187+
188+
if err != nil {
189+
return dbsqlerrint.NewDriverError(ctx, "error reading local file", err)
190+
}
191+
192+
req, _ := http.NewRequest("PUT", presignedUrl, bytes.NewReader(dat))
193+
194+
for k, v := range headers {
195+
req.Header.Set(k, v)
196+
}
197+
res, err := client.Do(req)
198+
if err != nil {
199+
return dbsqlerrint.NewDriverError(ctx, "error sending http request", err)
200+
}
201+
defer res.Body.Close()
202+
content, err := io.ReadAll(res.Body)
203+
204+
if err != nil || !Succeeded(res) {
205+
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil)
206+
}
207+
return nil
208+
209+
}
210+
211+
func (c *conn) HandleStagingGet(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError {
212+
if localFile == "" {
213+
return dbsqlerrint.NewDriverError(ctx, "cannot perform GET without specifying a local_file", nil)
214+
}
215+
client := &http.Client{}
216+
req, _ := http.NewRequest("GET", presignedUrl, nil)
217+
218+
for k, v := range headers {
219+
req.Header.Set(k, v)
220+
}
221+
res, err := client.Do(req)
222+
if err != nil {
223+
return dbsqlerrint.NewDriverError(ctx, "error sending http request", err)
224+
}
225+
defer res.Body.Close()
226+
content, err := io.ReadAll(res.Body)
227+
228+
if err != nil || !Succeeded(res) {
229+
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil)
230+
}
231+
232+
err = os.WriteFile(localFile, content, 0644) //nolint:gosec
233+
if err != nil {
234+
return dbsqlerrint.NewDriverError(ctx, "error writing local file", err)
235+
}
236+
return nil
237+
}
238+
239+
func (c *conn) HandleStagingDelete(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError {
240+
client := &http.Client{}
241+
req, _ := http.NewRequest("DELETE", presignedUrl, nil)
242+
for k, v := range headers {
243+
req.Header.Set(k, v)
244+
}
245+
res, err := client.Do(req)
246+
if err != nil {
247+
return dbsqlerrint.NewDriverError(ctx, "error sending http request", err)
248+
}
249+
defer res.Body.Close()
250+
content, err := io.ReadAll(res.Body)
251+
252+
if err != nil || !Succeeded(res) {
253+
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil)
254+
}
255+
256+
return nil
257+
}
258+
259+
func localPathIsAllowed(stagingAllowedLocalPaths []string, localFile string) bool {
260+
for i := range stagingAllowedLocalPaths {
261+
// Convert both filepaths to absolute paths to avoid potential issues.
262+
//
263+
path, err := filepath.Abs(stagingAllowedLocalPaths[i])
264+
if err != nil {
265+
return false
266+
}
267+
localFile, err := filepath.Abs(localFile)
268+
if err != nil {
269+
return false
270+
}
271+
relativePath, err := filepath.Rel(path, localFile)
272+
if err != nil {
273+
return false
274+
}
275+
if !strings.Contains(relativePath, "../") {
276+
return true
277+
}
278+
}
279+
return false
280+
}
281+
282+
func (c *conn) ExecStagingOperation(
283+
ctx context.Context,
284+
row driver.Rows) dbsqlerr.DBError {
285+
286+
var sqlRow []driver.Value
287+
colNames := row.Columns()
288+
sqlRow = make([]driver.Value, len(colNames))
289+
err := row.Next(sqlRow)
290+
if err != nil {
291+
return dbsqlerrint.NewDriverError(ctx, "error fetching staging operation results", err)
292+
}
293+
var stringValues []string = make([]string, 4)
294+
for i := range stringValues {
295+
if s, ok := sqlRow[i].(string); ok {
296+
stringValues[i] = s
297+
} else {
298+
return dbsqlerrint.NewDriverError(ctx, "received unexpected response from the server.", nil)
299+
}
300+
}
301+
operation := stringValues[0]
302+
presignedUrl := stringValues[1]
303+
headersByteArr := []byte(stringValues[2])
304+
var headers map[string]string
305+
if err := json.Unmarshal(headersByteArr, &headers); err != nil {
306+
return dbsqlerrint.NewDriverError(ctx, "error parsing server response.", nil)
307+
}
308+
localFile := stringValues[3]
309+
stagingAllowedLocalPaths := driverctx.StagingPathsFromContext(ctx)
310+
switch operation {
311+
case "PUT":
312+
if localPathIsAllowed(stagingAllowedLocalPaths, localFile) {
313+
return c.HandleStagingPut(ctx, presignedUrl, headers, localFile)
314+
} else {
315+
return dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil)
316+
}
317+
case "GET":
318+
if localPathIsAllowed(stagingAllowedLocalPaths, localFile) {
319+
return c.HandleStagingGet(ctx, presignedUrl, headers, localFile)
320+
} else {
321+
return dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil)
322+
}
323+
case "DELETE":
324+
return c.HandleStagingDelete(ctx, presignedUrl, headers)
325+
default:
326+
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("operation %s is not supported. Supported operations are GET, PUT, and REMOVE", operation), nil)
327+
}
328+
}
329+
136330
// QueryContext executes a query that may return rows, such as a
137331
// SELECT.
138332
//

connection_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ func TestConn_executeStatement(t *testing.T) {
121121
},
122122
}
123123

124+
getResultSetMetadata := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (_r *cli_service.TGetResultSetMetadataResp, _err error) {
125+
var b = false
126+
return &cli_service.TGetResultSetMetadataResp{IsStagingOperation: &b}, nil
127+
}
128+
124129
testClient := &client.TestClient{
125130
FnExecuteStatement: func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
126131
executeStatementCount++
@@ -130,6 +135,7 @@ func TestConn_executeStatement(t *testing.T) {
130135
closeOperationCount++
131136
return &cli_service.TCloseOperationResp{}, nil
132137
},
138+
FnGetResultSetMetadata: getResultSetMetadata,
133139
}
134140
testConn := &conn{
135141
session: getTestSession(),
@@ -1103,6 +1109,10 @@ func TestConn_ExecContext(t *testing.T) {
11031109
}
11041110
return getOperationStatusResp, nil
11051111
}
1112+
getResultSetMetadata := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (_r *cli_service.TGetResultSetMetadataResp, _err error) {
1113+
var b = false
1114+
return &cli_service.TGetResultSetMetadataResp{IsStagingOperation: &b}, nil
1115+
}
11061116

11071117
testClient := &client.TestClient{
11081118
FnExecuteStatement: executeStatement,
@@ -1112,6 +1122,7 @@ func TestConn_ExecContext(t *testing.T) {
11121122
assert.NoError(t, ctxErr)
11131123
return &cli_service.TCloseOperationResp{}, nil
11141124
},
1125+
FnGetResultSetMetadata: getResultSetMetadata,
11151126
}
11161127
testConn := &conn{
11171128
session: getTestSession(),
@@ -1155,6 +1166,11 @@ func TestConn_ExecContext(t *testing.T) {
11551166
return getOperationStatusResp, nil
11561167
}
11571168

1169+
getResultSetMetadata := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (_r *cli_service.TGetResultSetMetadataResp, _err error) {
1170+
var b = false
1171+
return &cli_service.TGetResultSetMetadataResp{IsStagingOperation: &b}, nil
1172+
}
1173+
11581174
testClient := &client.TestClient{
11591175
FnExecuteStatement: executeStatement,
11601176
FnGetOperationStatus: getOperationStatus,
@@ -1173,7 +1189,9 @@ func TestConn_ExecContext(t *testing.T) {
11731189
}
11741190
return cancelOperationResp, nil
11751191
},
1192+
FnGetResultSetMetadata: getResultSetMetadata,
11761193
}
1194+
11771195
testConn := &conn{
11781196
session: getTestSession(),
11791197
client: testClient,

driver_e2e_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ func TestWorkflowExample(t *testing.T) {
3131
)
3232
state := &callState{}
3333
// load basic responses
34+
loadTestData(t, "GetResultSetMetadataNotStaging.json", &state.getResultSetMetadataResp)
3435
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
3536
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
3637
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)
@@ -259,6 +260,7 @@ func TestContextTimeoutExample(t *testing.T) {
259260
_ = logger.SetLogLevel("debug")
260261
state := &callState{}
261262
// load basic responses
263+
loadTestData(t, "GetResultSetMetadataNotStaging.json", &state.getResultSetMetadataResp)
262264
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
263265
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
264266
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)
@@ -308,6 +310,7 @@ func TestRetries(t *testing.T) {
308310
_ = logger.SetLogLevel("debug")
309311
state := &callState{}
310312
// load basic responses
313+
loadTestData(t, "GetResultSetMetadataNotStaging.json", &state.getResultSetMetadataResp)
311314
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
312315
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
313316
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)
@@ -333,6 +336,7 @@ func TestRetries(t *testing.T) {
333336
_ = logger.SetLogLevel("debug")
334337
state := &callState{}
335338
// load basic responses
339+
loadTestData(t, "GetResultSetMetadataNotStaging.json", &state.getResultSetMetadataResp)
336340
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
337341
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
338342
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)
@@ -358,6 +362,7 @@ func TestRetries(t *testing.T) {
358362
_ = logger.SetLogLevel("debug")
359363
state := &callState{}
360364
// load basic responses
365+
loadTestData(t, "GetResultSetMetadataNotStaging.json", &state.getResultSetMetadataResp)
361366
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
362367
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
363368
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)
@@ -392,6 +397,7 @@ func TestRetries(t *testing.T) {
392397
_ = logger.SetLogLevel("debug")
393398
state := &callState{}
394399
// load basic responses
400+
loadTestData(t, "GetResultSetMetadataNotStaging.json", &state.getResultSetMetadataResp)
395401
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
396402
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
397403
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)
@@ -426,6 +432,7 @@ func TestRetries(t *testing.T) {
426432
_ = logger.SetLogLevel("debug")
427433
state := &callState{}
428434
// load basic responses
435+
loadTestData(t, "GetResultSetMetadataNotStaging.json", &state.getResultSetMetadataResp)
429436
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
430437
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
431438
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)

driverctx/ctx.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ const (
1414
QueryIdContextKey
1515
QueryIdCallbackKey
1616
ConnIdCallbackKey
17+
StagingAllowedLocalPathKey
1718
)
1819

1920
type IdCallbackFunc func(string)
@@ -79,10 +80,27 @@ func QueryIdFromContext(ctx context.Context) string {
7980
return queryId
8081
}
8182

83+
// QueryIdFromContext retrieves the queryId stored in context.
84+
func StagingPathsFromContext(ctx context.Context) []string {
85+
if ctx == nil {
86+
return []string{}
87+
}
88+
89+
stagingAllowedLocalPath, ok := ctx.Value(StagingAllowedLocalPathKey).([]string)
90+
if !ok {
91+
return []string{}
92+
}
93+
return stagingAllowedLocalPath
94+
}
95+
8296
func NewContextWithQueryIdCallback(ctx context.Context, callback IdCallbackFunc) context.Context {
8397
return context.WithValue(ctx, QueryIdCallbackKey, callback)
8498
}
8599

86100
func NewContextWithConnIdCallback(ctx context.Context, callback IdCallbackFunc) context.Context {
87101
return context.WithValue(ctx, ConnIdCallbackKey, callback)
88102
}
103+
104+
func NewContextWithStagingInfo(ctx context.Context, stagingAllowedLocalPath []string) context.Context {
105+
return context.WithValue(ctx, StagingAllowedLocalPathKey, stagingAllowedLocalPath)
106+
}

0 commit comments

Comments
 (0)