Skip to content

Commit b22c359

Browse files
committed
Extract SQLRowCount implementation
Update doc Co-Authored-By: alinalibq <[email protected]>
1 parent b2e8f25 commit b22c359

File tree

4 files changed

+67
-2
lines changed

4 files changed

+67
-2
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,8 +1126,13 @@ SQLRETURN SQLNumResultCols(SQLHSTMT stmt, SQLSMALLINT* column_count_ptr) {
11261126
SQLRETURN SQLRowCount(SQLHSTMT stmt, SQLLEN* row_count_ptr) {
11271127
ARROW_LOG(DEBUG) << "SQLRowCount called with stmt: " << stmt
11281128
<< ", column_count_ptr: " << static_cast<const void*>(row_count_ptr);
1129-
// GH-47713 TODO: Implement SQLRowCount
1130-
return SQL_INVALID_HANDLE;
1129+
1130+
using ODBC::ODBCStatement;
1131+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1132+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1133+
statement->GetRowCount(row_count_ptr);
1134+
return SQL_SUCCESS;
1135+
});
11311136
}
11321137

11331138
SQLRETURN SQLTables(SQLHSTMT stmt, SQLWCHAR* catalog_name,

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,17 @@ bool ODBCStatement::GetData(SQLSMALLINT record_number, SQLSMALLINT c_type,
751751
data_ptr, buffer_length, indicator_ptr);
752752
}
753753

754+
void ODBCStatement::GetRowCount(SQLLEN* row_count_ptr) {
755+
if (!row_count_ptr) {
756+
// row count pointer is not valid, do nothing as ODBC spec does not mention this as an
757+
// error
758+
return;
759+
}
760+
// Will always be -1 (meaning number of rows unknown) since only SELECT is supported by
761+
// driver
762+
*row_count_ptr = -1;
763+
}
764+
754765
void ODBCStatement::ReleaseStatement() {
755766
CloseCursor(true);
756767
connection_.DropStatement(this);

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ class ODBCStatement : public ODBCHandle<ODBCStatement> {
8080
bool GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, SQLPOINTER data_ptr,
8181
SQLLEN buffer_length, SQLLEN* indicator_ptr);
8282

83+
/// \brief Return number of rows affected by an UPDATE, INSERT, or DELETE statement\
84+
///
85+
/// -1 is returned as driver only supports SELECT statement
86+
void GetRowCount(SQLLEN* row_count_ptr);
87+
8388
/**
8489
* @brief Closes the cursor. This does _not_ un-prepare the statement or change
8590
* bindings.

cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,4 +217,48 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsErrorOnBadInputs) {
217217
VerifyOdbcErrorState(SQL_HANDLE_DBC, this->conn, kErrorStateHY090);
218218
}
219219

220+
TYPED_TEST(StatementTest, SQLRowCountReturnsNegativeOneOnSelect) {
221+
SQLLEN row_count = 0;
222+
SQLLEN expected_value = -1;
223+
SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3";
224+
SQLINTEGER query_length = static_cast<SQLINTEGER>(wcslen(sql_query));
225+
226+
ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length));
227+
228+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
229+
230+
CheckIntColumn(this->stmt, 1, 1);
231+
CheckStringColumnW(this->stmt, 2, L"One");
232+
CheckIntColumn(this->stmt, 3, 3);
233+
234+
ASSERT_EQ(SQL_SUCCESS, SQLRowCount(this->stmt, &row_count));
235+
236+
EXPECT_EQ(expected_value, row_count);
237+
}
238+
239+
TYPED_TEST(StatementTest, SQLRowCountReturnsSuccessOnNullptr) {
240+
SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3";
241+
SQLINTEGER query_length = static_cast<SQLINTEGER>(wcslen(sql_query));
242+
243+
ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length));
244+
245+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
246+
247+
CheckIntColumn(this->stmt, 1, 1);
248+
CheckStringColumnW(this->stmt, 2, L"One");
249+
CheckIntColumn(this->stmt, 3, 3);
250+
251+
ASSERT_EQ(SQL_SUCCESS, SQLRowCount(this->stmt, nullptr));
252+
}
253+
254+
TYPED_TEST(StatementTest, SQLRowCountFunctionSequenceErrorOnNoQuery) {
255+
SQLLEN row_count = 0;
256+
SQLLEN expected_value = 0;
257+
258+
ASSERT_EQ(SQL_ERROR, SQLRowCount(this->stmt, &row_count));
259+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010);
260+
261+
EXPECT_EQ(expected_value, row_count);
262+
}
263+
220264
} // namespace arrow::flight::sql::odbc

0 commit comments

Comments
 (0)