Skip to content

Commit 99d8ec6

Browse files
committed
Use RAII helper for allocating and freeing env/conn handles
Avoids duplicated code
1 parent 10c48f9 commit 99d8ec6

File tree

3 files changed

+67
-76
lines changed

3 files changed

+67
-76
lines changed

cpp/src/arrow/flight/sql/odbc/tests/errors_test.cc

Lines changed: 27 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,14 @@ using TestTypesOdbcV2 =
4040
::testing::Types<FlightSQLOdbcV2MockTestBase, FlightSQLOdbcV2RemoteTestBase>;
4141
TYPED_TEST_SUITE(ErrorsOdbcV2Test, TestTypesOdbcV2);
4242

43-
TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForConnectFailure) {
44-
SQLHENV env;
45-
SQLHDBC conn;
46-
47-
// Allocate an environment handle
48-
ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env));
49-
50-
ASSERT_EQ(SQL_SUCCESS,
51-
SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0));
43+
template <typename T>
44+
class ErrorsHandleTest : public T {};
5245

53-
// Allocate a connection using alloc handle
54-
ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DBC, env, &conn));
46+
using TestTypesHandle = ::testing::Types<FlightSQLOdbcEnvConnHandleMockTestBase,
47+
FlightSQLOdbcEnvConnHandleRemoteTestBase>;
48+
TYPED_TEST_SUITE(ErrorsHandleTest, TestTypesHandle);
5549

50+
TYPED_TEST(ErrorsHandleTest, TestSQLGetDiagFieldWForConnectFailure) {
5651
// Invalid connect string
5752
std::string connect_str = this->GetInvalidConnectionString();
5853

@@ -65,7 +60,7 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForConnectFailure) {
6560

6661
// Connecting to ODBC server.
6762
ASSERT_EQ(SQL_ERROR,
68-
SQLDriverConnect(conn, NULL, &connect_str0[0],
63+
SQLDriverConnect(this->conn, NULL, &connect_str0[0],
6964
static_cast<SQLSMALLINT>(connect_str0.size()), out_str,
7065
kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT));
7166

@@ -78,7 +73,7 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForConnectFailure) {
7873
SQLSMALLINT diag_number_length;
7974

8075
EXPECT_EQ(SQL_SUCCESS,
81-
SQLGetDiagField(SQL_HANDLE_DBC, conn, HEADER_LEVEL, SQL_DIAG_NUMBER,
76+
SQLGetDiagField(SQL_HANDLE_DBC, this->conn, HEADER_LEVEL, SQL_DIAG_NUMBER,
8277
&diag_number, sizeof(SQLINTEGER), &diag_number_length));
8378

8479
EXPECT_EQ(1, diag_number);
@@ -88,16 +83,16 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForConnectFailure) {
8883
SQLSMALLINT server_name_length;
8984

9085
EXPECT_EQ(SQL_SUCCESS,
91-
SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_SERVER_NAME,
92-
server_name, kOdbcBufferSize, &server_name_length));
86+
SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_SERVER_NAME,
87+
server_name, ODBC_BUFFER_SIZE, &server_name_length));
9388

9489
// SQL_DIAG_MESSAGE_TEXT
9590
SQLWCHAR message_text[kOdbcBufferSize];
9691
SQLSMALLINT message_text_length;
9792

9893
EXPECT_EQ(SQL_SUCCESS,
99-
SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT,
100-
message_text, kOdbcBufferSize, &message_text_length));
94+
SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT,
95+
message_text, ODBC_BUFFER_SIZE, &message_text_length));
10196

10297
EXPECT_GT(message_text_length, 100);
10398

@@ -106,8 +101,8 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForConnectFailure) {
106101
SQLSMALLINT diag_native_length;
107102

108103
EXPECT_EQ(SQL_SUCCESS,
109-
SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_NATIVE, &diag_native,
110-
sizeof(diag_native), &diag_native_length));
104+
SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_NATIVE,
105+
&diag_native, sizeof(diag_native), &diag_native_length));
111106

112107
EXPECT_EQ(200, diag_native);
113108

@@ -116,34 +111,18 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForConnectFailure) {
116111
SQLWCHAR sql_state[sql_state_size];
117112
SQLSMALLINT sql_state_length;
118113

119-
EXPECT_EQ(SQL_SUCCESS,
120-
SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_SQLSTATE, sql_state,
121-
sql_state_size * arrow::flight::sql::odbc::GetSqlWCharSize(),
122-
&sql_state_length));
114+
EXPECT_EQ(
115+
SQL_SUCCESS,
116+
SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_SQLSTATE, sql_state,
117+
sql_state_size * arrow::flight::sql::odbc::GetSqlWCharSize(),
118+
&sql_state_length));
123119

124120
EXPECT_EQ(std::wstring(L"28000"), std::wstring(sql_state));
125-
126-
// Free connection handle
127-
EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn));
128-
129-
// Free environment handle
130-
EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env));
131121
}
132122

133-
TYPED_TEST(ErrorsTest, DISABLED_TestSQLGetDiagFieldWForConnectFailureNTS) {
123+
TYPED_TEST(ErrorsHandleTest, DISABLED_TestSQLGetDiagFieldWForConnectFailureNTS) {
134124
// Test is disabled because driver manager on Windows does not pass through SQL_NTS
135125
// This test case can be potentially used on macOS/Linux
136-
SQLHENV env;
137-
SQLHDBC conn;
138-
139-
// Allocate an environment handle
140-
ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env));
141-
142-
ASSERT_EQ(SQL_SUCCESS,
143-
SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0));
144-
145-
// Allocate a connection using alloc handle
146-
ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DBC, env, &conn));
147126

148127
// Invalid connect string
149128
std::string connect_str = this->GetInvalidConnectionString();
@@ -157,7 +136,7 @@ TYPED_TEST(ErrorsTest, DISABLED_TestSQLGetDiagFieldWForConnectFailureNTS) {
157136

158137
// Connecting to ODBC server.
159138
ASSERT_EQ(SQL_ERROR,
160-
SQLDriverConnect(conn, NULL, &connect_str0[0],
139+
SQLDriverConnect(this->conn, NULL, &connect_str0[0],
161140
static_cast<SQLSMALLINT>(connect_str0.size()), out_str,
162141
kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT));
163142

@@ -171,16 +150,10 @@ TYPED_TEST(ErrorsTest, DISABLED_TestSQLGetDiagFieldWForConnectFailureNTS) {
171150
message_text[kOdbcBufferSize - 1] = '\0';
172151

173152
ASSERT_EQ(SQL_SUCCESS,
174-
SQLGetDiagField(SQL_HANDLE_DBC, conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT,
153+
SQLGetDiagField(SQL_HANDLE_DBC, this->conn, RECORD_1, SQL_DIAG_MESSAGE_TEXT,
175154
message_text, SQL_NTS, &message_text_length));
176155

177156
EXPECT_GT(message_text_length, 100);
178-
179-
// Free connection handle
180-
ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn));
181-
182-
// Free environment handle
183-
ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env));
184157
}
185158

186159
TYPED_TEST(ErrorsTest, TestSQLGetDiagFieldWForDescriptorFailureFromDriverManager) {
@@ -280,20 +253,7 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagRecForDescriptorFailureFromDriverManager) {
280253
EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DESC, descriptor));
281254
}
282255

283-
TYPED_TEST(ErrorsTest, TestSQLGetDiagRecForConnectFailure) {
284-
// ODBC Environment
285-
SQLHENV env;
286-
SQLHDBC conn;
287-
288-
// Allocate an environment handle
289-
ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env));
290-
291-
ASSERT_EQ(SQL_SUCCESS,
292-
SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0));
293-
294-
// Allocate a connection using alloc handle
295-
ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DBC, env, &conn));
296-
256+
TYPED_TEST(ErrorsHandleTest, TestSQLGetDiagRecForConnectFailure) {
297257
// Invalid connect string
298258
std::string connect_str = this->GetInvalidConnectionString();
299259

@@ -306,16 +266,17 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagRecForConnectFailure) {
306266

307267
// Connecting to ODBC server.
308268
ASSERT_EQ(SQL_ERROR,
309-
SQLDriverConnect(conn, NULL, &connect_str0[0],
269+
SQLDriverConnect(this->conn, NULL, &connect_str0[0],
310270
static_cast<SQLSMALLINT>(connect_str0.size()), out_str,
311271
kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT));
312272

313273
SQLWCHAR sql_state[6];
314274
SQLINTEGER native_error;
315275
SQLWCHAR message[kOdbcBufferSize];
316276
SQLSMALLINT message_length;
317-
ASSERT_EQ(SQL_SUCCESS, SQLGetDiagRec(SQL_HANDLE_DBC, conn, 1, sql_state, &native_error,
318-
message, kOdbcBufferSize, &message_length));
277+
ASSERT_EQ(SQL_SUCCESS,
278+
SQLGetDiagRec(SQL_HANDLE_DBC, this->conn, 1, sql_state, &native_error,
279+
message, ODBC_BUFFER_SIZE, &message_length));
319280

320281
EXPECT_GT(message_length, 120);
321282

@@ -324,12 +285,6 @@ TYPED_TEST(ErrorsTest, TestSQLGetDiagRecForConnectFailure) {
324285
EXPECT_EQ(std::wstring(L"28000"), std::wstring(sql_state));
325286

326287
EXPECT_TRUE(!std::wstring(message).empty());
327-
328-
// Free connection handle
329-
ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn));
330-
331-
// Free environment handle
332-
ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env));
333288
}
334289

335290
TYPED_TEST(ErrorsTest, TestSQLGetDiagRecInputData) {

cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,16 @@ void FlightSQLOdbcV2RemoteTestBase::SetUp() {
169169
connected_ = true;
170170
}
171171

172+
void FlightSQLOdbcEnvConnHandleRemoteTestBase::SetUp() { AllocEnvConnHandles(); }
173+
174+
void FlightSQLOdbcEnvConnHandleRemoteTestBase::TearDown() {
175+
// Free connection handle
176+
EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn));
177+
178+
// Free environment handle
179+
EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env));
180+
}
181+
172182
std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers) {
173183
// Lambda function to compare characters without case sensitivity.
174184
auto char_compare = [](const char& char1, const char& char2) {
@@ -357,6 +367,21 @@ void FlightSQLOdbcV2MockTestBase::SetUp() {
357367
connected_ = true;
358368
}
359369

370+
void FlightSQLOdbcEnvConnHandleMockTestBase::SetUp() {
371+
this->Initialize();
372+
AllocEnvConnHandles();
373+
}
374+
375+
void FlightSQLOdbcEnvConnHandleMockTestBase::TearDown() {
376+
// Free connection handle
377+
EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn));
378+
379+
// Free environment handle
380+
EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env));
381+
382+
ASSERT_OK(server_->Shutdown());
383+
}
384+
360385
bool CompareConnPropertyMap(Connection::ConnPropertyMap map1,
361386
Connection::ConnPropertyMap map2) {
362387
if (map1.size() != map2.size()) return false;

cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,15 @@ class FlightSQLOdbcV2RemoteTestBase : public FlightSQLODBCRemoteTestBase {
9191
void SetUp() override;
9292
};
9393

94-
static constexpr std::string_view kAuthorizationHeader = "authorization";
95-
static constexpr std::string_view kBearerPrefix = "Bearer ";
96-
static constexpr std::string_view kTestToken = "t0k3n";
94+
class FlightSQLOdbcEnvConnHandleRemoteTestBase : public FlightSQLODBCRemoteTestBase {
95+
protected:
96+
void SetUp() override;
97+
void TearDown() override;
98+
};
99+
100+
static constexpr std::string_view authorization_header = "authorization";
101+
static constexpr std::string_view bearer_prefix = "Bearer ";
102+
static constexpr std::string_view test_token = "t0k3n";
97103

98104
std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers);
99105

@@ -158,7 +164,6 @@ class FlightSQLODBCMockTestBase : public FlightSQLODBCRemoteTestBase {
158164

159165
void TearDown() override;
160166

161-
private:
162167
std::shared_ptr<arrow::flight::sql::example::SQLiteFlightSqlServer> server_;
163168
};
164169

@@ -170,6 +175,12 @@ class FlightSQLOdbcV2MockTestBase : public FlightSQLODBCMockTestBase {
170175
void SetUp() override;
171176
};
172177

178+
class FlightSQLOdbcEnvConnHandleMockTestBase : public FlightSQLODBCMockTestBase {
179+
protected:
180+
void SetUp() override;
181+
void TearDown() override;
182+
};
183+
173184
/** ODBC read buffer size. */
174185
static constexpr int kOdbcBufferSize = 1024;
175186

0 commit comments

Comments
 (0)