11package dbsql
22
33import (
4+ "bytes"
45 "context"
56 "database/sql/driver"
7+ "encoding/json"
8+ "fmt"
9+ "io"
10+ "net/http"
11+ "os"
12+ "path/filepath"
13+ "strings"
614 "time"
715
816 "github.com/databricks/databricks-sql-go/driverctx"
@@ -94,6 +102,7 @@ func (c *conn) IsValid() bool {
94102// ExecContext honors the context timeout and return when it is canceled.
95103// Statement ExecContext is the same as connection ExecContext
96104func (c * conn ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
105+
97106 corrId := driverctx .CorrelationIdFromContext (ctx )
98107 log := logger .WithContext (c .id , corrId , "" )
99108 msg , start := logger .Track ("ExecContext" )
@@ -108,6 +117,34 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
108117 exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
109118
110119 if exStmtResp != nil && exStmtResp .OperationHandle != nil {
120+ var isStagingOperation bool
121+ if exStmtResp .DirectResults != nil && exStmtResp .DirectResults .ResultSetMetadata != nil && exStmtResp .DirectResults .ResultSetMetadata .IsStagingOperation != nil {
122+ isStagingOperation = * exStmtResp .DirectResults .ResultSetMetadata .IsStagingOperation
123+ } else {
124+ req := cli_service.TGetResultSetMetadataReq {
125+ OperationHandle : exStmtResp .OperationHandle ,
126+ }
127+ resp , err := c .client .GetResultSetMetadata (ctx , & req )
128+ if err != nil {
129+ return nil , dbsqlerrint .NewDriverError (ctx , "error performing staging operation" , err )
130+ }
131+ isStagingOperation = * resp .IsStagingOperation
132+ }
133+ if isStagingOperation {
134+ if len (driverctx .StagingPathsFromContext (ctx )) != 0 {
135+ row , err := rows .NewRows (c .id , corrId , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
136+ if err != nil {
137+ return nil , dbsqlerrint .NewDriverError (ctx , "error reading row." , err )
138+ }
139+ err = c .ExecStagingOperation (ctx , row )
140+ if err != nil {
141+ return nil , err
142+ }
143+ } else {
144+ return nil , dbsqlerrint .NewDriverError (ctx , "staging ctx must be provided." , nil )
145+ }
146+ }
147+
111148 // we have an operation id so update the logger
112149 log = logger .WithContext (c .id , corrId , client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID ))
113150
@@ -133,6 +170,163 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
133170 return & res , nil
134171}
135172
173+ func Succeeded (response * http.Response ) bool {
174+ if response .StatusCode == 200 || response .StatusCode == 201 || response .StatusCode == 202 || response .StatusCode == 204 {
175+ return true
176+ }
177+ return false
178+ }
179+
180+ func (c * conn ) HandleStagingPut (ctx context.Context , presignedUrl string , headers map [string ]string , localFile string ) dbsqlerr.DBError {
181+ if localFile == "" {
182+ return dbsqlerrint .NewDriverError (ctx , "cannot perform PUT without specifying a local_file" , nil )
183+ }
184+ client := & http.Client {}
185+
186+ dat , err := os .ReadFile (localFile )
187+
188+ if err != nil {
189+ return dbsqlerrint .NewDriverError (ctx , "error reading local file" , err )
190+ }
191+
192+ req , _ := http .NewRequest ("PUT" , presignedUrl , bytes .NewReader (dat ))
193+
194+ for k , v := range headers {
195+ req .Header .Set (k , v )
196+ }
197+ res , err := client .Do (req )
198+ if err != nil {
199+ return dbsqlerrint .NewDriverError (ctx , "error sending http request" , err )
200+ }
201+ defer res .Body .Close ()
202+ content , err := io .ReadAll (res .Body )
203+
204+ if err != nil || ! Succeeded (res ) {
205+ return dbsqlerrint .NewDriverError (ctx , fmt .Sprintf ("staging operation over HTTP was unsuccessful: %d-%s" , res .StatusCode , content ), nil )
206+ }
207+ return nil
208+
209+ }
210+
211+ func (c * conn ) HandleStagingGet (ctx context.Context , presignedUrl string , headers map [string ]string , localFile string ) dbsqlerr.DBError {
212+ if localFile == "" {
213+ return dbsqlerrint .NewDriverError (ctx , "cannot perform GET without specifying a local_file" , nil )
214+ }
215+ client := & http.Client {}
216+ req , _ := http .NewRequest ("GET" , presignedUrl , nil )
217+
218+ for k , v := range headers {
219+ req .Header .Set (k , v )
220+ }
221+ res , err := client .Do (req )
222+ if err != nil {
223+ return dbsqlerrint .NewDriverError (ctx , "error sending http request" , err )
224+ }
225+ defer res .Body .Close ()
226+ content , err := io .ReadAll (res .Body )
227+
228+ if err != nil || ! Succeeded (res ) {
229+ return dbsqlerrint .NewDriverError (ctx , fmt .Sprintf ("staging operation over HTTP was unsuccessful: %d-%s" , res .StatusCode , content ), nil )
230+ }
231+
232+ err = os .WriteFile (localFile , content , 0644 ) //nolint:gosec
233+ if err != nil {
234+ return dbsqlerrint .NewDriverError (ctx , "error writing local file" , err )
235+ }
236+ return nil
237+ }
238+
239+ func (c * conn ) HandleStagingDelete (ctx context.Context , presignedUrl string , headers map [string ]string ) dbsqlerr.DBError {
240+ client := & http.Client {}
241+ req , _ := http .NewRequest ("DELETE" , presignedUrl , nil )
242+ for k , v := range headers {
243+ req .Header .Set (k , v )
244+ }
245+ res , err := client .Do (req )
246+ if err != nil {
247+ return dbsqlerrint .NewDriverError (ctx , "error sending http request" , err )
248+ }
249+ defer res .Body .Close ()
250+ content , err := io .ReadAll (res .Body )
251+
252+ if err != nil || ! Succeeded (res ) {
253+ return dbsqlerrint .NewDriverError (ctx , fmt .Sprintf ("staging operation over HTTP was unsuccessful: %d-%s, nil" , res .StatusCode , content ), nil )
254+ }
255+
256+ return nil
257+ }
258+
259+ func localPathIsAllowed (stagingAllowedLocalPaths []string , localFile string ) bool {
260+ for i := range stagingAllowedLocalPaths {
261+ // Convert both filepaths to absolute paths to avoid potential issues.
262+ //
263+ path , err := filepath .Abs (stagingAllowedLocalPaths [i ])
264+ if err != nil {
265+ return false
266+ }
267+ localFile , err := filepath .Abs (localFile )
268+ if err != nil {
269+ return false
270+ }
271+ relativePath , err := filepath .Rel (path , localFile )
272+ if err != nil {
273+ return false
274+ }
275+ if ! strings .Contains (relativePath , "../" ) {
276+ return true
277+ }
278+ }
279+ return false
280+ }
281+
282+ func (c * conn ) ExecStagingOperation (
283+ ctx context.Context ,
284+ row driver.Rows ) dbsqlerr.DBError {
285+
286+ var sqlRow []driver.Value
287+ colNames := row .Columns ()
288+ sqlRow = make ([]driver.Value , len (colNames ))
289+ err := row .Next (sqlRow )
290+ if err != nil {
291+ return dbsqlerrint .NewDriverError (ctx , "error fetching staging operation results" , err )
292+ }
293+ var stringValues []string = make ([]string , 4 )
294+ for i := range stringValues {
295+ if s , ok := sqlRow [i ].(string ); ok {
296+ stringValues [i ] = s
297+ } else {
298+ return dbsqlerrint .NewDriverError (ctx , "received unexpected response from the server." , nil )
299+ }
300+ }
301+ operation := stringValues [0 ]
302+ presignedUrl := stringValues [1 ]
303+ headersByteArr := []byte (stringValues [2 ])
304+ var headers map [string ]string
305+ if err := json .Unmarshal (headersByteArr , & headers ); err != nil {
306+ return dbsqlerrint .NewDriverError (ctx , "error parsing server response." , nil )
307+ }
308+ localFile := stringValues [3 ]
309+ stagingAllowedLocalPaths := driverctx .StagingPathsFromContext (ctx )
310+ switch operation {
311+ case "PUT" :
312+ if localPathIsAllowed (stagingAllowedLocalPaths , localFile ) {
313+ return c .HandleStagingPut (ctx , presignedUrl , headers , localFile )
314+ } else {
315+ return dbsqlerrint .NewDriverError (ctx , "local file operations are restricted to paths within the configured stagingAllowedLocalPath" , nil )
316+ }
317+ case "GET" :
318+ if localPathIsAllowed (stagingAllowedLocalPaths , localFile ) {
319+ return c .HandleStagingGet (ctx , presignedUrl , headers , localFile )
320+ } else {
321+ return dbsqlerrint .NewDriverError (ctx , "local file operations are restricted to paths within the configured stagingAllowedLocalPath" , nil )
322+ }
323+ case "DELETE" :
324+ return c .HandleStagingDelete (ctx , presignedUrl , headers )
325+ default :
326+ return dbsqlerrint .NewDriverError (ctx , fmt .Sprintf ("operation %s is not supported. Supported operations are GET, PUT, and REMOVE" , operation ), nil )
327+ }
328+ }
329+
136330// QueryContext executes a query that may return rows, such as a
137331// SELECT.
138332//
0 commit comments