Skip to content

Commit fbb57d0

Browse files
authored
Increased coverage to statement and connector (#54)
Increased coverage via unit tests to 89.1% with tests added for `statement.go`, `connector.go` Signed-off-by: Matthew Kim <[email protected]>
2 parents 73934fc + 0eb87c7 commit fbb57d0

File tree

5 files changed

+362
-1
lines changed

5 files changed

+362
-1
lines changed

connection_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,90 @@ func TestConn_ResetSession(t *testing.T) {
10871087
})
10881088
}
10891089

1090+
func TestConn_Close(t *testing.T) {
1091+
t.Run("Close will call CloseSession", func(t *testing.T) {
1092+
var closeSessionCount int
1093+
1094+
closeSession := func(ctx context.Context, req *cli_service.TCloseSessionReq) (r *cli_service.TCloseSessionResp, err error) {
1095+
closeSessionCount++
1096+
closeSessionResp := &cli_service.TCloseSessionResp{
1097+
Status: &cli_service.TStatus{
1098+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
1099+
},
1100+
}
1101+
return closeSessionResp, nil
1102+
}
1103+
1104+
testClient := &client.TestClient{
1105+
FnCloseSession: closeSession,
1106+
}
1107+
testConn := &conn{
1108+
session: getTestSession(),
1109+
client: testClient,
1110+
cfg: config.WithDefaults(),
1111+
}
1112+
err := testConn.Close()
1113+
1114+
assert.NoError(t, err)
1115+
assert.Equal(t, 1, closeSessionCount)
1116+
})
1117+
1118+
t.Run("Close will err when CloseSession fails", func(t *testing.T) {
1119+
var closeSessionCount int
1120+
1121+
closeSession := func(ctx context.Context, req *cli_service.TCloseSessionReq) (r *cli_service.TCloseSessionResp, err error) {
1122+
closeSessionCount++
1123+
closeSessionResp := &cli_service.TCloseSessionResp{
1124+
Status: &cli_service.TStatus{
1125+
StatusCode: cli_service.TStatusCode_ERROR_STATUS,
1126+
},
1127+
}
1128+
return closeSessionResp, fmt.Errorf("error")
1129+
}
1130+
1131+
testClient := &client.TestClient{
1132+
FnCloseSession: closeSession,
1133+
}
1134+
testConn := &conn{
1135+
session: getTestSession(),
1136+
client: testClient,
1137+
cfg: config.WithDefaults(),
1138+
}
1139+
err := testConn.Close()
1140+
1141+
assert.Error(t, err)
1142+
assert.Equal(t, 1, closeSessionCount)
1143+
})
1144+
}
1145+
1146+
func TestConn_Prepare(t *testing.T) {
1147+
t.Run("Prepare returns stmt struct", func(t *testing.T) {
1148+
testClient := &client.TestClient{}
1149+
testConn := &conn{
1150+
session: getTestSession(),
1151+
client: testClient,
1152+
cfg: config.WithDefaults(),
1153+
}
1154+
stmt, err := testConn.Prepare("query string")
1155+
assert.NoError(t, err)
1156+
assert.NotNil(t, stmt)
1157+
})
1158+
}
1159+
1160+
func TestConn_PrepareContext(t *testing.T) {
1161+
t.Run("PrepareContext returns stmt struct", func(t *testing.T) {
1162+
testClient := &client.TestClient{}
1163+
testConn := &conn{
1164+
session: getTestSession(),
1165+
client: testClient,
1166+
cfg: config.WithDefaults(),
1167+
}
1168+
stmt, err := testConn.PrepareContext(context.Background(), "query string")
1169+
assert.NoError(t, err)
1170+
assert.NotNil(t, stmt)
1171+
})
1172+
}
1173+
10901174
func getTestSession() *cli_service.TOpenSessionResp {
10911175
return &cli_service.TOpenSessionResp{SessionHandle: &cli_service.TSessionHandle{
10921176
SessionId: &cli_service.THandleIdentifier{

connector_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package dbsql
2+
3+
import (
4+
"context"
5+
"github.com/databricks/databricks-sql-go/internal/config"
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
"testing"
9+
"time"
10+
)
11+
12+
func TestConnector_Connect(t *testing.T) {
13+
t.Run("Connect returns err when thrift client initialization fails", func(t *testing.T) {
14+
cfg := config.WithDefaults()
15+
cfg.ThriftProtocol = "invalidprotocol"
16+
17+
testConnector := connector{
18+
cfg: cfg,
19+
}
20+
conn, err := testConnector.Connect(context.Background())
21+
assert.Nil(t, conn)
22+
assert.Error(t, err)
23+
})
24+
}
25+
26+
func TestNewConnector(t *testing.T) {
27+
t.Run("Connector initialized with functional options should have all options set", func(t *testing.T) {
28+
host := "databricks-host"
29+
port := 1
30+
accessToken := "token"
31+
httpPath := "http-path"
32+
maxRows := 100
33+
timeout := 100 * time.Second
34+
catalog := "catalog-name"
35+
schema := "schema-string"
36+
userAgentEntry := "user-agent"
37+
sessionParams := map[string]string{"key": "value"}
38+
con, err := NewConnector(
39+
WithServerHostname(host),
40+
WithPort(port),
41+
WithAccessToken(accessToken),
42+
WithHTTPPath(httpPath),
43+
WithMaxRows(maxRows),
44+
WithTimeout(timeout),
45+
WithInitialNamespace(catalog, schema),
46+
WithUserAgentEntry(userAgentEntry),
47+
WithSessionParams(sessionParams),
48+
)
49+
expectedUserConfig := config.UserConfig{
50+
Host: host,
51+
Port: port,
52+
Protocol: "https",
53+
AccessToken: accessToken,
54+
HTTPPath: httpPath,
55+
MaxRows: maxRows,
56+
QueryTimeout: timeout,
57+
Catalog: catalog,
58+
Schema: schema,
59+
UserAgentEntry: userAgentEntry,
60+
SessionParams: sessionParams,
61+
}
62+
expectedCfg := config.WithDefaults()
63+
expectedCfg.UserConfig = expectedUserConfig
64+
coni, ok := con.(*connector)
65+
require.True(t, ok)
66+
assert.Nil(t, err)
67+
assert.Equal(t, expectedCfg, coni.cfg)
68+
})
69+
}

internal/config/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (ucfg UserConfig) DeepCopy() UserConfig {
119119
}
120120

121121
func (ucfg UserConfig) WithDefaults() UserConfig {
122-
if ucfg.MaxRows == 0 {
122+
if ucfg.MaxRows <= 0 {
123123
ucfg.MaxRows = 10000
124124
}
125125
if ucfg.Protocol == "" {

rows_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ func TestColumnsWithDirectResults(t *testing.T) {
481481
var getMetadataCount, fetchResultsCount int
482482

483483
rowSet := &rows{}
484+
defer rowSet.Close()
484485
client := getRowsTestSimpleClient(&getMetadataCount, &fetchResultsCount)
485486

486487
req := &cli_service.TFetchResultsReq{
@@ -757,6 +758,47 @@ func TestColumnTypeLength(t *testing.T) {
757758
}
758759
}
759760

761+
func TestColumnTypeDatabaseTypeName(t *testing.T) {
762+
var getMetadataCount, fetchResultsCount int
763+
764+
rowSet := &rows{}
765+
client := getRowsTestSimpleClient(&getMetadataCount, &fetchResultsCount)
766+
rowSet.client = client
767+
768+
resp, err := rowSet.getResultMetadata()
769+
assert.Nil(t, err)
770+
771+
cols := resp.Schema.Columns
772+
expectedScanTypes := []reflect.Type{
773+
scanTypeBoolean,
774+
scanTypeInt8,
775+
scanTypeInt16,
776+
scanTypeInt32,
777+
scanTypeInt64,
778+
scanTypeFloat32,
779+
scanTypeFloat64,
780+
scanTypeString,
781+
scanTypeDateTime,
782+
scanTypeRawBytes,
783+
scanTypeRawBytes,
784+
scanTypeRawBytes,
785+
scanTypeRawBytes,
786+
scanTypeRawBytes,
787+
scanTypeDateTime,
788+
scanTypeString,
789+
scanTypeString,
790+
}
791+
792+
assert.Equal(t, len(expectedScanTypes), len(cols))
793+
794+
scanTypes := make([]reflect.Type, len(cols))
795+
for i := range cols {
796+
scanTypes[i] = rowSet.ColumnTypeScanType(i)
797+
}
798+
799+
assert.Equal(t, expectedScanTypes, scanTypes)
800+
}
801+
760802
type rowTestPagingResult struct {
761803
getMetadataCount int
762804
fetchResultsCount int

statement_test.go

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package dbsql
2+
3+
import (
4+
"context"
5+
"database/sql/driver"
6+
"github.com/apache/thrift/lib/go/thrift"
7+
"github.com/databricks/databricks-sql-go/internal/cli_service"
8+
"github.com/databricks/databricks-sql-go/internal/client"
9+
"github.com/databricks/databricks-sql-go/internal/config"
10+
"github.com/stretchr/testify/assert"
11+
"testing"
12+
)
13+
14+
func TestStmt_Close(t *testing.T) {
15+
t.Run("Close is not applicable", func(t *testing.T) {
16+
testStmt := stmt{
17+
conn: &conn{},
18+
query: "query string",
19+
}
20+
err := testStmt.Close()
21+
assert.Nil(t, err)
22+
})
23+
}
24+
25+
func TestStmt_NumInput(t *testing.T) {
26+
t.Run("NumInput is not applicable", func(t *testing.T) {
27+
testStmt := stmt{
28+
conn: &conn{},
29+
query: "query string",
30+
}
31+
numInput := testStmt.NumInput()
32+
assert.Equal(t, -1, numInput)
33+
})
34+
}
35+
36+
func TestStmt_Exec(t *testing.T) {
37+
t.Run("Exec is not implemented", func(t *testing.T) {
38+
testStmt := stmt{
39+
conn: &conn{},
40+
query: "query string",
41+
}
42+
res, err := testStmt.Exec([]driver.Value{})
43+
assert.Nil(t, res)
44+
assert.Error(t, err)
45+
})
46+
}
47+
48+
func TestStmt_Query(t *testing.T) {
49+
t.Run("Query is not implemented", func(t *testing.T) {
50+
testStmt := stmt{
51+
conn: &conn{},
52+
query: "query string",
53+
}
54+
res, err := testStmt.Query([]driver.Value{})
55+
assert.Nil(t, res)
56+
assert.Error(t, err)
57+
})
58+
}
59+
60+
func TestStmt_ExecContext(t *testing.T) {
61+
t.Run("ExecContext returns number of rows modified when execution is successful", func(t *testing.T) {
62+
var executeStatementCount, getOperationStatusCount int
63+
var savedQueryString string
64+
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
65+
executeStatementCount++
66+
savedQueryString = req.Statement
67+
executeStatementResp := &cli_service.TExecuteStatementResp{
68+
Status: &cli_service.TStatus{
69+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
70+
},
71+
OperationHandle: &cli_service.TOperationHandle{
72+
OperationId: &cli_service.THandleIdentifier{
73+
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 2, 3, 4, 4, 223, 34, 54},
74+
Secret: []byte("b"),
75+
},
76+
},
77+
}
78+
return executeStatementResp, nil
79+
}
80+
81+
getOperationStatus := func(ctx context.Context, req *cli_service.TGetOperationStatusReq) (r *cli_service.TGetOperationStatusResp, err error) {
82+
getOperationStatusCount++
83+
getOperationStatusResp := &cli_service.TGetOperationStatusResp{
84+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
85+
NumModifiedRows: thrift.Int64Ptr(10),
86+
}
87+
return getOperationStatusResp, nil
88+
}
89+
90+
testClient := &client.TestClient{
91+
FnExecuteStatement: executeStatement,
92+
FnGetOperationStatus: getOperationStatus,
93+
}
94+
testConn := &conn{
95+
session: getTestSession(),
96+
client: testClient,
97+
cfg: config.WithDefaults(),
98+
}
99+
testQuery := "insert 10"
100+
testStmt := &stmt{
101+
conn: testConn,
102+
query: testQuery,
103+
}
104+
res, err := testStmt.ExecContext(context.Background(), []driver.NamedValue{})
105+
106+
assert.NoError(t, err)
107+
assert.NotNil(t, res)
108+
rowsAffected, _ := res.RowsAffected()
109+
assert.Equal(t, int64(10), rowsAffected)
110+
assert.Equal(t, 1, executeStatementCount)
111+
assert.Equal(t, testQuery, savedQueryString)
112+
})
113+
}
114+
115+
func TestStmt_QueryContext(t *testing.T) {
116+
t.Run("QueryContext returns rows object upon successful query", func(t *testing.T) {
117+
var executeStatementCount, getOperationStatusCount int
118+
var savedQueryString string
119+
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
120+
executeStatementCount++
121+
savedQueryString = req.Statement
122+
executeStatementResp := &cli_service.TExecuteStatementResp{
123+
Status: &cli_service.TStatus{
124+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
125+
},
126+
OperationHandle: &cli_service.TOperationHandle{
127+
OperationId: &cli_service.THandleIdentifier{
128+
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 2, 3, 4, 4, 223, 34, 54},
129+
Secret: []byte("b"),
130+
},
131+
},
132+
}
133+
return executeStatementResp, nil
134+
}
135+
136+
getOperationStatus := func(ctx context.Context, req *cli_service.TGetOperationStatusReq) (r *cli_service.TGetOperationStatusResp, err error) {
137+
getOperationStatusCount++
138+
getOperationStatusResp := &cli_service.TGetOperationStatusResp{
139+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
140+
NumModifiedRows: thrift.Int64Ptr(10),
141+
}
142+
return getOperationStatusResp, nil
143+
}
144+
145+
testClient := &client.TestClient{
146+
FnExecuteStatement: executeStatement,
147+
FnGetOperationStatus: getOperationStatus,
148+
}
149+
testConn := &conn{
150+
session: getTestSession(),
151+
client: testClient,
152+
cfg: config.WithDefaults(),
153+
}
154+
testQuery := "select 1"
155+
testStmt := &stmt{
156+
conn: testConn,
157+
query: testQuery,
158+
}
159+
rows, err := testStmt.QueryContext(context.Background(), []driver.NamedValue{})
160+
161+
assert.NoError(t, err)
162+
assert.NotNil(t, rows)
163+
assert.Equal(t, 1, executeStatementCount)
164+
assert.Equal(t, testQuery, savedQueryString)
165+
})
166+
}

0 commit comments

Comments
 (0)