88 "github.com/apache/arrow/go/v12/arrow/array"
99 "github.com/apache/arrow/go/v12/arrow/ipc"
1010 "github.com/apache/arrow/go/v12/arrow/memory"
11+ dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1112 "github.com/databricks/databricks-sql-go/internal/cli_service"
1213 "github.com/databricks/databricks-sql-go/internal/config"
1314 dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
@@ -17,6 +18,7 @@ import (
1718 "net/http/httptest"
1819 "reflect"
1920 "testing"
21+ "time"
2022)
2123
2224func generateMockArrowBytes () []byte {
@@ -58,15 +60,20 @@ func TestBatchLoader(t *testing.T) {
5860 testTable := []struct {
5961 name string
6062 response func (w http.ResponseWriter , r * http.Request )
63+ linkExpired bool
6164 expectedResponse []* sparkArrowBatch
6265 expectedErr error
6366 }{
6467 {
6568 name : "cloud-fetch-happy-case" ,
6669 response : func (w http.ResponseWriter , r * http.Request ) {
6770 w .WriteHeader (http .StatusOK )
68- w .Write (generateMockArrowBytes ())
71+ _ , err := w .Write (generateMockArrowBytes ())
72+ if err != nil {
73+ panic (err )
74+ }
6975 },
76+ linkExpired : false ,
7077 expectedResponse : []* sparkArrowBatch {
7178 {
7279 arrowRecordBytes : generateMockArrowBytes (),
@@ -78,11 +85,25 @@ func TestBatchLoader(t *testing.T) {
7885 },
7986 expectedErr : nil ,
8087 },
88+ {
89+ name : "cloud-fetch-expired_link" ,
90+ response : func (w http.ResponseWriter , r * http.Request ) {
91+ w .WriteHeader (http .StatusOK )
92+ _ , err := w .Write (generateMockArrowBytes ())
93+ if err != nil {
94+ panic (err )
95+ }
96+ },
97+ linkExpired : true ,
98+ expectedResponse : nil ,
99+ expectedErr : errors .New (dbsqlerr .ErrLinkExpired ),
100+ },
81101 {
82102 name : "cloud-fetch-http-error" ,
83103 response : func (w http.ResponseWriter , r * http.Request ) {
84104 w .WriteHeader (http .StatusInternalServerError )
85105 },
106+ linkExpired : false ,
86107 expectedResponse : nil ,
87108 expectedErr : dbsqlerrint .NewDriverError (context .TODO (), errArrowRowsCloudFetchDownloadFailure , nil ),
88109 },
@@ -92,9 +113,18 @@ func TestBatchLoader(t *testing.T) {
92113 t .Run (tc .name , func (t * testing.T ) {
93114 handler = tc .response
94115
116+ expiryTime := time .Now ()
117+ // If link expired, subtract 1 sec from current time to get expiration time
118+ if tc .linkExpired {
119+ expiryTime = expiryTime .Add (- 1 * time .Second )
120+ } else {
121+ expiryTime = expiryTime .Add (1 * time .Second )
122+ }
123+
95124 cu := & cloudURL {
96125 TSparkArrowResultLink : & cli_service.TSparkArrowResultLink {
97- FileLink : server .URL ,
126+ FileLink : server .URL ,
127+ ExpiryTime : expiryTime .Unix (),
98128 },
99129 }
100130
0 commit comments