Skip to content

Commit d0e098f

Browse files
authored
GH-46575: [C++][FlightRPC] ODBC Diagnostics Report (#47763)
### Rationale for this change ODBC needs to provide diagnostic information so users can debug the error ### What changes are included in this PR? - Implementation of SQLGetDiagField and SQLGetDiagRec Tests are included in separate PR (see #47764) ### Are these changes tested? Tests will be in a separate PR (see #47764). Other APIs depend on SQLGetDiagField and SQLGetDiagRec to get error reporting functionality, and tests for SQLGetDiagField and SQLGetDiagRec depend on other APIs for creating errors, as these diagnostic APIs alone do not initiate any errors. Changes tested locally ### Are there any user-facing changes? No * GitHub Issue: #46575 Authored-by: Alina (Xi) Li <[email protected]> Signed-off-by: Sutou Kouhei <[email protected]>
1 parent 06f53b2 commit d0e098f

File tree

2 files changed

+444
-79
lines changed

2 files changed

+444
-79
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 353 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,20 @@ SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) {
218218
return SQL_INVALID_HANDLE;
219219
}
220220

221+
inline bool IsValidStringFieldArgs(SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
222+
SQLSMALLINT* string_length_ptr, bool is_unicode) {
223+
const SQLSMALLINT char_size = is_unicode ? GetSqlWCharSize() : sizeof(char);
224+
const bool has_valid_buffer =
225+
buffer_length == SQL_NTS || (buffer_length >= 0 && buffer_length % char_size == 0);
226+
227+
// regardless of capacity return false if invalid
228+
if (diag_info_ptr && !has_valid_buffer) {
229+
return false;
230+
}
231+
232+
return has_valid_buffer || string_length_ptr;
233+
}
234+
221235
SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle,
222236
SQLSMALLINT rec_number, SQLSMALLINT diag_identifier,
223237
SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
@@ -229,8 +243,258 @@ SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle,
229243
<< ", diag_info_ptr: " << diag_info_ptr
230244
<< ", buffer_length: " << buffer_length << ", string_length_ptr: "
231245
<< static_cast<const void*>(string_length_ptr);
232-
// GH-46575 TODO: Implement SQLGetDiagField
233-
return SQL_INVALID_HANDLE;
246+
// GH-46575 TODO: Add tests for SQLGetDiagField
247+
using ODBC::GetStringAttribute;
248+
using ODBC::ODBCConnection;
249+
using ODBC::ODBCDescriptor;
250+
using ODBC::ODBCEnvironment;
251+
using ODBC::ODBCStatement;
252+
253+
if (!handle) {
254+
return SQL_INVALID_HANDLE;
255+
}
256+
257+
if (!diag_info_ptr && !string_length_ptr) {
258+
return SQL_ERROR;
259+
}
260+
261+
// If buffer length derived from null terminated string
262+
if (diag_info_ptr && buffer_length == SQL_NTS) {
263+
const wchar_t* str = reinterpret_cast<wchar_t*>(diag_info_ptr);
264+
buffer_length = wcslen(str) * GetSqlWCharSize();
265+
}
266+
267+
// Set character type to be Unicode by default
268+
const bool is_unicode = true;
269+
Diagnostics* diagnostics = nullptr;
270+
271+
switch (handle_type) {
272+
case SQL_HANDLE_ENV: {
273+
ODBCEnvironment* environment = reinterpret_cast<ODBCEnvironment*>(handle);
274+
diagnostics = &environment->GetDiagnostics();
275+
break;
276+
}
277+
278+
case SQL_HANDLE_DBC: {
279+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(handle);
280+
diagnostics = &connection->GetDiagnostics();
281+
break;
282+
}
283+
284+
case SQL_HANDLE_DESC: {
285+
ODBCDescriptor* descriptor = reinterpret_cast<ODBCDescriptor*>(handle);
286+
diagnostics = &descriptor->GetDiagnostics();
287+
break;
288+
}
289+
290+
case SQL_HANDLE_STMT: {
291+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(handle);
292+
diagnostics = &statement->GetDiagnostics();
293+
break;
294+
}
295+
296+
default:
297+
return SQL_ERROR;
298+
}
299+
300+
if (!diagnostics) {
301+
return SQL_ERROR;
302+
}
303+
304+
// Retrieve and return if header level diagnostics
305+
switch (diag_identifier) {
306+
case SQL_DIAG_NUMBER: {
307+
if (diag_info_ptr) {
308+
*static_cast<SQLINTEGER*>(diag_info_ptr) =
309+
static_cast<SQLINTEGER>(diagnostics->GetRecordCount());
310+
}
311+
312+
if (string_length_ptr) {
313+
*string_length_ptr = sizeof(SQLINTEGER);
314+
}
315+
316+
return SQL_SUCCESS;
317+
}
318+
319+
// Driver manager implements SQL_DIAG_RETURNCODE
320+
case SQL_DIAG_RETURNCODE: {
321+
return SQL_SUCCESS;
322+
}
323+
324+
case SQL_DIAG_CURSOR_ROW_COUNT: {
325+
if (handle_type == SQL_HANDLE_STMT) {
326+
if (diag_info_ptr) {
327+
// Will always be 0 if only SELECT supported
328+
*static_cast<SQLLEN*>(diag_info_ptr) = 0;
329+
}
330+
331+
if (string_length_ptr) {
332+
*string_length_ptr = sizeof(SQLLEN);
333+
}
334+
335+
return SQL_SUCCESS;
336+
}
337+
338+
return SQL_ERROR;
339+
}
340+
341+
// Not supported
342+
case SQL_DIAG_DYNAMIC_FUNCTION:
343+
case SQL_DIAG_DYNAMIC_FUNCTION_CODE: {
344+
if (handle_type == SQL_HANDLE_STMT) {
345+
return SQL_SUCCESS;
346+
}
347+
348+
return SQL_ERROR;
349+
}
350+
351+
case SQL_DIAG_ROW_COUNT: {
352+
if (handle_type == SQL_HANDLE_STMT) {
353+
if (diag_info_ptr) {
354+
// Will always be 0 if only SELECT is supported
355+
*static_cast<SQLLEN*>(diag_info_ptr) = 0;
356+
}
357+
358+
if (string_length_ptr) {
359+
*string_length_ptr = sizeof(SQLLEN);
360+
}
361+
362+
return SQL_SUCCESS;
363+
}
364+
365+
return SQL_ERROR;
366+
}
367+
}
368+
369+
// If not a diagnostic header field then the record number must be 1 or greater
370+
if (rec_number < 1) {
371+
return SQL_ERROR;
372+
}
373+
374+
// Retrieve record level diagnostics from specified 1 based record
375+
const uint32_t record_index = static_cast<uint32_t>(rec_number - 1);
376+
if (!diagnostics->HasRecord(record_index)) {
377+
return SQL_NO_DATA;
378+
}
379+
380+
// Retrieve record field data
381+
switch (diag_identifier) {
382+
case SQL_DIAG_MESSAGE_TEXT: {
383+
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
384+
is_unicode)) {
385+
const std::string& message = diagnostics->GetMessageText(record_index);
386+
return GetStringAttribute(is_unicode, message, true, diag_info_ptr, buffer_length,
387+
string_length_ptr, *diagnostics);
388+
}
389+
390+
return SQL_ERROR;
391+
}
392+
393+
case SQL_DIAG_NATIVE: {
394+
if (diag_info_ptr) {
395+
*static_cast<SQLINTEGER*>(diag_info_ptr) =
396+
diagnostics->GetNativeError(record_index);
397+
}
398+
399+
if (string_length_ptr) {
400+
*string_length_ptr = sizeof(SQLINTEGER);
401+
}
402+
403+
return SQL_SUCCESS;
404+
}
405+
406+
case SQL_DIAG_SERVER_NAME: {
407+
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
408+
is_unicode)) {
409+
switch (handle_type) {
410+
case SQL_HANDLE_DBC: {
411+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(handle);
412+
std::string dsn = connection->GetDSN();
413+
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length,
414+
string_length_ptr, *diagnostics);
415+
}
416+
417+
case SQL_HANDLE_DESC: {
418+
ODBCDescriptor* descriptor = reinterpret_cast<ODBCDescriptor*>(handle);
419+
ODBCConnection* connection = &descriptor->GetConnection();
420+
std::string dsn = connection->GetDSN();
421+
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length,
422+
string_length_ptr, *diagnostics);
423+
break;
424+
}
425+
426+
case SQL_HANDLE_STMT: {
427+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(handle);
428+
ODBCConnection* connection = &statement->GetConnection();
429+
std::string dsn = connection->GetDSN();
430+
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length,
431+
string_length_ptr, *diagnostics);
432+
}
433+
434+
default:
435+
return SQL_ERROR;
436+
}
437+
}
438+
439+
return SQL_ERROR;
440+
}
441+
442+
case SQL_DIAG_SQLSTATE: {
443+
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
444+
is_unicode)) {
445+
const std::string& state = diagnostics->GetSQLState(record_index);
446+
return GetStringAttribute(is_unicode, state, true, diag_info_ptr, buffer_length,
447+
string_length_ptr, *diagnostics);
448+
}
449+
450+
return SQL_ERROR;
451+
}
452+
453+
// Return valid dummy variable for unimplemented field
454+
case SQL_DIAG_COLUMN_NUMBER: {
455+
if (diag_info_ptr) {
456+
*static_cast<SQLINTEGER*>(diag_info_ptr) = SQL_NO_COLUMN_NUMBER;
457+
}
458+
459+
if (string_length_ptr) {
460+
*string_length_ptr = sizeof(SQLINTEGER);
461+
}
462+
463+
return SQL_SUCCESS;
464+
}
465+
466+
// Return empty string dummy variable for unimplemented fields
467+
case SQL_DIAG_CLASS_ORIGIN:
468+
case SQL_DIAG_CONNECTION_NAME:
469+
case SQL_DIAG_SUBCLASS_ORIGIN: {
470+
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
471+
is_unicode)) {
472+
return GetStringAttribute(is_unicode, "", true, diag_info_ptr, buffer_length,
473+
string_length_ptr, *diagnostics);
474+
}
475+
476+
return SQL_ERROR;
477+
}
478+
479+
// Return valid dummy variable for unimplemented field
480+
case SQL_DIAG_ROW_NUMBER: {
481+
if (diag_info_ptr) {
482+
*static_cast<SQLLEN*>(diag_info_ptr) = SQL_NO_ROW_NUMBER;
483+
}
484+
485+
if (string_length_ptr) {
486+
*string_length_ptr = sizeof(SQLLEN);
487+
}
488+
489+
return SQL_SUCCESS;
490+
}
491+
492+
default: {
493+
return SQL_ERROR;
494+
}
495+
}
496+
497+
return SQL_ERROR;
234498
}
235499

236500
SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT rec_number,
@@ -244,8 +508,93 @@ SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT r
244508
<< ", message_text: " << static_cast<const void*>(message_text)
245509
<< ", buffer_length: " << buffer_length
246510
<< ", text_length_ptr: " << static_cast<const void*>(text_length_ptr);
247-
// GH-46575 TODO: Implement SQLGetDiagRec
248-
return SQL_INVALID_HANDLE;
511+
// GH-46575 TODO: Add tests for SQLGetDiagRec
512+
using arrow::flight::sql::odbc::Diagnostics;
513+
using ODBC::GetStringAttribute;
514+
using ODBC::ODBCConnection;
515+
using ODBC::ODBCDescriptor;
516+
using ODBC::ODBCEnvironment;
517+
using ODBC::ODBCStatement;
518+
519+
if (!handle) {
520+
return SQL_INVALID_HANDLE;
521+
}
522+
523+
// Record number must be greater or equal to 1
524+
if (rec_number < 1 || buffer_length < 0) {
525+
return SQL_ERROR;
526+
}
527+
528+
// Set character type to be Unicode by default
529+
const bool is_unicode = true;
530+
Diagnostics* diagnostics = nullptr;
531+
532+
switch (handle_type) {
533+
case SQL_HANDLE_ENV: {
534+
auto* environment = ODBCEnvironment::Of(handle);
535+
diagnostics = &environment->GetDiagnostics();
536+
break;
537+
}
538+
539+
case SQL_HANDLE_DBC: {
540+
auto* connection = ODBCConnection::Of(handle);
541+
diagnostics = &connection->GetDiagnostics();
542+
break;
543+
}
544+
545+
case SQL_HANDLE_DESC: {
546+
auto* descriptor = ODBCDescriptor::Of(handle);
547+
diagnostics = &descriptor->GetDiagnostics();
548+
break;
549+
}
550+
551+
case SQL_HANDLE_STMT: {
552+
auto* statement = ODBCStatement::Of(handle);
553+
diagnostics = &statement->GetDiagnostics();
554+
break;
555+
}
556+
557+
default:
558+
return SQL_INVALID_HANDLE;
559+
}
560+
561+
if (!diagnostics) {
562+
return SQL_ERROR;
563+
}
564+
565+
// Convert from ODBC 1 based record number to internal diagnostics 0 indexed storage
566+
const size_t record_index = static_cast<size_t>(rec_number - 1);
567+
if (!diagnostics->HasRecord(record_index)) {
568+
return SQL_NO_DATA;
569+
}
570+
571+
if (sql_state) {
572+
// The length of the sql state is always 5 characters plus null
573+
SQLSMALLINT size = 6;
574+
const std::string& state = diagnostics->GetSQLState(record_index);
575+
576+
// Microsoft documentation does not mention
577+
// any SQLGetDiagRec return value that is associated with `sql_state` buffer, so
578+
// the return value for writing `sql_state` buffer is ignored by the driver.
579+
ARROW_UNUSED(GetStringAttribute(is_unicode, state, false, sql_state, size, &size,
580+
*diagnostics));
581+
}
582+
583+
if (native_error_ptr) {
584+
*native_error_ptr = diagnostics->GetNativeError(record_index);
585+
}
586+
587+
if (message_text || text_length_ptr) {
588+
const std::string& message = diagnostics->GetMessageText(record_index);
589+
590+
// According to Microsoft documentation,
591+
// SQL_SUCCESS_WITH_INFO should be returned if `*message_text` buffer was too
592+
// small to hold the requested diagnostic message.
593+
return GetStringAttribute(is_unicode, message, false, message_text, buffer_length,
594+
text_length_ptr, *diagnostics);
595+
}
596+
597+
return SQL_SUCCESS;
249598
}
250599

251600
SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER value_ptr,

0 commit comments

Comments
 (0)