Skip to content

Commit a09234a

Browse files
committed
Extract SQLExtendedFetch implementation
Fix comments and use `ARROW_UNUSED` Use nullptr explicitly Co-Authored-By: alinalibq <[email protected]>
1 parent f65ee2c commit a09234a

File tree

4 files changed

+127
-7
lines changed

4 files changed

+127
-7
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,8 +1076,33 @@ SQLRETURN SQLExtendedFetch(SQLHSTMT stmt, SQLUSMALLINT fetch_orientation,
10761076
<< ", row_count_ptr: " << static_cast<const void*>(row_count_ptr)
10771077
<< ", row_status_array: "
10781078
<< static_cast<const void*>(row_status_array);
1079-
// GH-47714 TODO: Implement SQLExtendedFetch
1080-
return SQL_INVALID_HANDLE;
1079+
1080+
using ODBC::ODBCDescriptor;
1081+
using ODBC::ODBCStatement;
1082+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1083+
// Only SQL_FETCH_NEXT forward-only fetching orientation is supported,
1084+
// meaning the behavior of SQLExtendedFetch is same as SQLFetch.
1085+
if (fetch_orientation != SQL_FETCH_NEXT) {
1086+
throw DriverException("Optional feature not supported.", "HYC00");
1087+
}
1088+
// Ignore fetch_offset as it's not applicable to SQL_FETCH_NEXT
1089+
ARROW_UNUSED(fetch_offset);
1090+
1091+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1092+
1093+
// The SQL_ROWSET_SIZE statement attribute specifies the number of rows in the
1094+
// rowset. Retrieve it from GetRowsetSize.
1095+
SQLULEN row_set_size = statement->GetRowsetSize();
1096+
ARROW_LOG(DEBUG) << "SQL_ROWSET_SIZE value for SQLExtendedFetch: " << row_set_size;
1097+
1098+
if (statement->Fetch(static_cast<size_t>(row_set_size), row_count_ptr,
1099+
row_status_array)) {
1100+
return SQL_SUCCESS;
1101+
} else {
1102+
// Reached the end of rowset
1103+
return SQL_NO_DATA;
1104+
}
1105+
});
10811106
}
10821107

10831108
SQLRETURN SQLFetchScroll(SQLHSTMT stmt, SQLSMALLINT fetch_orientation,

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,8 @@ void ODBCStatement::ExecuteDirect(const std::string& query) {
316316
is_prepared_ = false;
317317
}
318318

319-
bool ODBCStatement::Fetch(size_t rows) {
319+
bool ODBCStatement::Fetch(size_t rows, SQLULEN* row_count_ptr,
320+
SQLUSMALLINT* row_status_array) {
320321
if (has_reached_end_of_result_) {
321322
ird_->SetRowsProcessed(0);
322323
return false;
@@ -349,11 +350,24 @@ bool ODBCStatement::Fetch(size_t rows) {
349350
current_ard_->NotifyBindingsHavePropagated();
350351
}
351352

352-
size_t rows_fetched = current_result_->Move(rows, current_ard_->GetBindOffset(),
353-
current_ard_->GetBoundStructOffset(),
354-
ird_->GetArrayStatusPtr());
353+
uint16_t* array_status_ptr;
354+
if (row_status_array) {
355+
// For SQLExtendedFetch only
356+
array_status_ptr = row_status_array;
357+
} else {
358+
array_status_ptr = ird_->GetArrayStatusPtr();
359+
}
360+
361+
size_t rows_fetched =
362+
current_result_->Move(rows, current_ard_->GetBindOffset(),
363+
current_ard_->GetBoundStructOffset(), array_status_ptr);
355364
ird_->SetRowsProcessed(static_cast<SQLULEN>(rows_fetched));
356365

366+
if (row_count_ptr) {
367+
// For SQLExtendedFetch only
368+
*row_count_ptr = rows_fetched;
369+
}
370+
357371
row_number_ += rows_fetched;
358372
has_reached_end_of_result_ = rows_fetched != rows;
359373
return rows_fetched != 0;

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ class ODBCStatement : public ODBCHandle<ODBCStatement> {
5959
void ExecuteDirect(const std::string& query);
6060

6161
/// \brief Return true if the number of rows fetch was greater than zero.
62-
bool Fetch(size_t rows);
62+
///
63+
/// row_count_ptr and row_status_array are optional arguments, they are only needed for
64+
/// SQLExtendedFetch
65+
bool Fetch(size_t rows, SQLULEN* row_count_ptr = 0, SQLUSMALLINT* row_status_array = 0);
66+
6367
bool IsPrepared() const;
6468

6569
void GetStmtAttr(SQLINTEGER statement_attribute, SQLPOINTER output,

cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,6 +1722,83 @@ TYPED_TEST(StatementTest, TestSQLBindColIndicatorOnlySQLUnbind) {
17221722
// EXPECT_EQ(1, char_val_ind);
17231723
}
17241724

1725+
TYPED_TEST(StatementTest, TestSQLExtendedFetchRowFetching) {
1726+
// Set SQL_ROWSET_SIZE to fetch 3 rows at once
1727+
1728+
constexpr SQLULEN rows = 3;
1729+
SQLINTEGER val[rows];
1730+
SQLLEN buf_len = sizeof(val);
1731+
SQLLEN ind[rows];
1732+
1733+
// Same variable will be used for column 1, the value of `val`
1734+
// should be updated after every SQLFetch call.
1735+
ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, val, buf_len, ind));
1736+
1737+
ASSERT_EQ(SQL_SUCCESS, SQLSetStmtAttr(this->stmt, SQL_ROWSET_SIZE,
1738+
reinterpret_cast<SQLPOINTER>(rows), 0));
1739+
1740+
std::wstring wsql =
1741+
LR"(
1742+
SELECT 1 AS small_table
1743+
UNION ALL
1744+
SELECT 2
1745+
UNION ALL
1746+
SELECT 3;
1747+
)";
1748+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
1749+
1750+
ASSERT_EQ(SQL_SUCCESS,
1751+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
1752+
1753+
// Fetch row 1-3.
1754+
SQLULEN row_count;
1755+
SQLUSMALLINT row_status[rows];
1756+
1757+
ASSERT_EQ(SQL_SUCCESS,
1758+
SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count, row_status));
1759+
EXPECT_EQ(3, row_count);
1760+
1761+
for (int i = 0; i < rows; i++) {
1762+
EXPECT_EQ(SQL_SUCCESS, row_status[i]);
1763+
}
1764+
1765+
// Verify 1 is returned for row 1
1766+
EXPECT_EQ(1, val[0]);
1767+
// Verify 2 is returned for row 2
1768+
EXPECT_EQ(2, val[1]);
1769+
// Verify 3 is returned for row 3
1770+
EXPECT_EQ(3, val[2]);
1771+
1772+
// Verify result set has no more data beyond row 3
1773+
SQLULEN row_count2;
1774+
SQLUSMALLINT row_status2[rows];
1775+
EXPECT_EQ(SQL_NO_DATA,
1776+
SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count2, row_status2));
1777+
}
1778+
1779+
TEST_F(StatementRemoteTest, DISABLED_TestSQLExtendedFetchQueryNullIndicator) {
1780+
// GH-47110: SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002
1781+
// Limitation on mock test server prevents null from working properly, so use remote
1782+
// server instead. Mock server has type `DENSE_UNION` for null column data.
1783+
SQLINTEGER val;
1784+
1785+
ASSERT_EQ(SQL_SUCCESS, SQLBindCol(this->stmt, 1, SQL_C_LONG, &val, 0, nullptr));
1786+
1787+
std::wstring wsql = L"SELECT null as null_col;";
1788+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
1789+
1790+
ASSERT_EQ(SQL_SUCCESS,
1791+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
1792+
1793+
SQLULEN row_count1;
1794+
SQLUSMALLINT row_status1[1];
1795+
1796+
// SQLExtendedFetch should return SQL_SUCCESS_WITH_INFO for 22002 state
1797+
ASSERT_EQ(SQL_SUCCESS_WITH_INFO,
1798+
SQLExtendedFetch(this->stmt, SQL_FETCH_NEXT, 0, &row_count1, row_status1));
1799+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState22002);
1800+
}
1801+
17251802
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) {
17261803
SQLWCHAR buf[1024];
17271804
SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();

0 commit comments

Comments
 (0)