@@ -218,6 +218,20 @@ SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) {
218218 return SQL_INVALID_HANDLE;
219219}
220220
221+ inline bool IsValidStringFieldArgs (SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
222+ SQLSMALLINT* string_length_ptr, bool is_unicode) {
223+ const SQLSMALLINT char_size = is_unicode ? GetSqlWCharSize () : sizeof (char );
224+ const bool has_valid_buffer =
225+ buffer_length == SQL_NTS || (buffer_length >= 0 && buffer_length % char_size == 0 );
226+
227+ // regardless of capacity return false if invalid
228+ if (diag_info_ptr && !has_valid_buffer) {
229+ return false ;
230+ }
231+
232+ return has_valid_buffer || string_length_ptr;
233+ }
234+
221235SQLRETURN SQLGetDiagField (SQLSMALLINT handle_type, SQLHANDLE handle,
222236 SQLSMALLINT rec_number, SQLSMALLINT diag_identifier,
223237 SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
@@ -229,8 +243,258 @@ SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle,
229243 << " , diag_info_ptr: " << diag_info_ptr
230244 << " , buffer_length: " << buffer_length << " , string_length_ptr: "
231245 << static_cast <const void *>(string_length_ptr);
232- // GH-46575 TODO: Implement SQLGetDiagField
233- return SQL_INVALID_HANDLE;
246+ // GH-46575 TODO: Add tests for SQLGetDiagField
247+ using ODBC::GetStringAttribute;
248+ using ODBC::ODBCConnection;
249+ using ODBC::ODBCDescriptor;
250+ using ODBC::ODBCEnvironment;
251+ using ODBC::ODBCStatement;
252+
253+ if (!handle) {
254+ return SQL_INVALID_HANDLE;
255+ }
256+
257+ if (!diag_info_ptr && !string_length_ptr) {
258+ return SQL_ERROR;
259+ }
260+
261+ // If buffer length derived from null terminated string
262+ if (diag_info_ptr && buffer_length == SQL_NTS) {
263+ const wchar_t * str = reinterpret_cast <wchar_t *>(diag_info_ptr);
264+ buffer_length = wcslen (str) * GetSqlWCharSize ();
265+ }
266+
267+ // Set character type to be Unicode by default
268+ const bool is_unicode = true ;
269+ Diagnostics* diagnostics = nullptr ;
270+
271+ switch (handle_type) {
272+ case SQL_HANDLE_ENV: {
273+ ODBCEnvironment* environment = reinterpret_cast <ODBCEnvironment*>(handle);
274+ diagnostics = &environment->GetDiagnostics ();
275+ break ;
276+ }
277+
278+ case SQL_HANDLE_DBC: {
279+ ODBCConnection* connection = reinterpret_cast <ODBCConnection*>(handle);
280+ diagnostics = &connection->GetDiagnostics ();
281+ break ;
282+ }
283+
284+ case SQL_HANDLE_DESC: {
285+ ODBCDescriptor* descriptor = reinterpret_cast <ODBCDescriptor*>(handle);
286+ diagnostics = &descriptor->GetDiagnostics ();
287+ break ;
288+ }
289+
290+ case SQL_HANDLE_STMT: {
291+ ODBCStatement* statement = reinterpret_cast <ODBCStatement*>(handle);
292+ diagnostics = &statement->GetDiagnostics ();
293+ break ;
294+ }
295+
296+ default :
297+ return SQL_ERROR;
298+ }
299+
300+ if (!diagnostics) {
301+ return SQL_ERROR;
302+ }
303+
304+ // Retrieve and return if header level diagnostics
305+ switch (diag_identifier) {
306+ case SQL_DIAG_NUMBER: {
307+ if (diag_info_ptr) {
308+ *static_cast <SQLINTEGER*>(diag_info_ptr) =
309+ static_cast <SQLINTEGER>(diagnostics->GetRecordCount ());
310+ }
311+
312+ if (string_length_ptr) {
313+ *string_length_ptr = sizeof (SQLINTEGER);
314+ }
315+
316+ return SQL_SUCCESS;
317+ }
318+
319+ // Driver manager implements SQL_DIAG_RETURNCODE
320+ case SQL_DIAG_RETURNCODE: {
321+ return SQL_SUCCESS;
322+ }
323+
324+ case SQL_DIAG_CURSOR_ROW_COUNT: {
325+ if (handle_type == SQL_HANDLE_STMT) {
326+ if (diag_info_ptr) {
327+ // Will always be 0 if only SELECT supported
328+ *static_cast <SQLLEN*>(diag_info_ptr) = 0 ;
329+ }
330+
331+ if (string_length_ptr) {
332+ *string_length_ptr = sizeof (SQLLEN);
333+ }
334+
335+ return SQL_SUCCESS;
336+ }
337+
338+ return SQL_ERROR;
339+ }
340+
341+ // Not supported
342+ case SQL_DIAG_DYNAMIC_FUNCTION:
343+ case SQL_DIAG_DYNAMIC_FUNCTION_CODE: {
344+ if (handle_type == SQL_HANDLE_STMT) {
345+ return SQL_SUCCESS;
346+ }
347+
348+ return SQL_ERROR;
349+ }
350+
351+ case SQL_DIAG_ROW_COUNT: {
352+ if (handle_type == SQL_HANDLE_STMT) {
353+ if (diag_info_ptr) {
354+ // Will always be 0 if only SELECT is supported
355+ *static_cast <SQLLEN*>(diag_info_ptr) = 0 ;
356+ }
357+
358+ if (string_length_ptr) {
359+ *string_length_ptr = sizeof (SQLLEN);
360+ }
361+
362+ return SQL_SUCCESS;
363+ }
364+
365+ return SQL_ERROR;
366+ }
367+ }
368+
369+ // If not a diagnostic header field then the record number must be 1 or greater
370+ if (rec_number < 1 ) {
371+ return SQL_ERROR;
372+ }
373+
374+ // Retrieve record level diagnostics from specified 1 based record
375+ const uint32_t record_index = static_cast <uint32_t >(rec_number - 1 );
376+ if (!diagnostics->HasRecord (record_index)) {
377+ return SQL_NO_DATA;
378+ }
379+
380+ // Retrieve record field data
381+ switch (diag_identifier) {
382+ case SQL_DIAG_MESSAGE_TEXT: {
383+ if (IsValidStringFieldArgs (diag_info_ptr, buffer_length, string_length_ptr,
384+ is_unicode)) {
385+ const std::string& message = diagnostics->GetMessageText (record_index);
386+ return GetStringAttribute (is_unicode, message, true , diag_info_ptr, buffer_length,
387+ string_length_ptr, *diagnostics);
388+ }
389+
390+ return SQL_ERROR;
391+ }
392+
393+ case SQL_DIAG_NATIVE: {
394+ if (diag_info_ptr) {
395+ *static_cast <SQLINTEGER*>(diag_info_ptr) =
396+ diagnostics->GetNativeError (record_index);
397+ }
398+
399+ if (string_length_ptr) {
400+ *string_length_ptr = sizeof (SQLINTEGER);
401+ }
402+
403+ return SQL_SUCCESS;
404+ }
405+
406+ case SQL_DIAG_SERVER_NAME: {
407+ if (IsValidStringFieldArgs (diag_info_ptr, buffer_length, string_length_ptr,
408+ is_unicode)) {
409+ switch (handle_type) {
410+ case SQL_HANDLE_DBC: {
411+ ODBCConnection* connection = reinterpret_cast <ODBCConnection*>(handle);
412+ std::string dsn = connection->GetDSN ();
413+ return GetStringAttribute (is_unicode, dsn, true , diag_info_ptr, buffer_length,
414+ string_length_ptr, *diagnostics);
415+ }
416+
417+ case SQL_HANDLE_DESC: {
418+ ODBCDescriptor* descriptor = reinterpret_cast <ODBCDescriptor*>(handle);
419+ ODBCConnection* connection = &descriptor->GetConnection ();
420+ std::string dsn = connection->GetDSN ();
421+ return GetStringAttribute (is_unicode, dsn, true , diag_info_ptr, buffer_length,
422+ string_length_ptr, *diagnostics);
423+ break ;
424+ }
425+
426+ case SQL_HANDLE_STMT: {
427+ ODBCStatement* statement = reinterpret_cast <ODBCStatement*>(handle);
428+ ODBCConnection* connection = &statement->GetConnection ();
429+ std::string dsn = connection->GetDSN ();
430+ return GetStringAttribute (is_unicode, dsn, true , diag_info_ptr, buffer_length,
431+ string_length_ptr, *diagnostics);
432+ }
433+
434+ default :
435+ return SQL_ERROR;
436+ }
437+ }
438+
439+ return SQL_ERROR;
440+ }
441+
442+ case SQL_DIAG_SQLSTATE: {
443+ if (IsValidStringFieldArgs (diag_info_ptr, buffer_length, string_length_ptr,
444+ is_unicode)) {
445+ const std::string& state = diagnostics->GetSQLState (record_index);
446+ return GetStringAttribute (is_unicode, state, true , diag_info_ptr, buffer_length,
447+ string_length_ptr, *diagnostics);
448+ }
449+
450+ return SQL_ERROR;
451+ }
452+
453+ // Return valid dummy variable for unimplemented field
454+ case SQL_DIAG_COLUMN_NUMBER: {
455+ if (diag_info_ptr) {
456+ *static_cast <SQLINTEGER*>(diag_info_ptr) = SQL_NO_COLUMN_NUMBER;
457+ }
458+
459+ if (string_length_ptr) {
460+ *string_length_ptr = sizeof (SQLINTEGER);
461+ }
462+
463+ return SQL_SUCCESS;
464+ }
465+
466+ // Return empty string dummy variable for unimplemented fields
467+ case SQL_DIAG_CLASS_ORIGIN:
468+ case SQL_DIAG_CONNECTION_NAME:
469+ case SQL_DIAG_SUBCLASS_ORIGIN: {
470+ if (IsValidStringFieldArgs (diag_info_ptr, buffer_length, string_length_ptr,
471+ is_unicode)) {
472+ return GetStringAttribute (is_unicode, " " , true , diag_info_ptr, buffer_length,
473+ string_length_ptr, *diagnostics);
474+ }
475+
476+ return SQL_ERROR;
477+ }
478+
479+ // Return valid dummy variable for unimplemented field
480+ case SQL_DIAG_ROW_NUMBER: {
481+ if (diag_info_ptr) {
482+ *static_cast <SQLLEN*>(diag_info_ptr) = SQL_NO_ROW_NUMBER;
483+ }
484+
485+ if (string_length_ptr) {
486+ *string_length_ptr = sizeof (SQLLEN);
487+ }
488+
489+ return SQL_SUCCESS;
490+ }
491+
492+ default : {
493+ return SQL_ERROR;
494+ }
495+ }
496+
497+ return SQL_ERROR;
234498}
235499
236500SQLRETURN SQLGetDiagRec (SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT rec_number,
@@ -244,8 +508,93 @@ SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT r
244508 << " , message_text: " << static_cast <const void *>(message_text)
245509 << " , buffer_length: " << buffer_length
246510 << " , text_length_ptr: " << static_cast <const void *>(text_length_ptr);
247- // GH-46575 TODO: Implement SQLGetDiagRec
248- return SQL_INVALID_HANDLE;
511+ // GH-46575 TODO: Add tests for SQLGetDiagRec
512+ using arrow::flight::sql::odbc::Diagnostics;
513+ using ODBC::GetStringAttribute;
514+ using ODBC::ODBCConnection;
515+ using ODBC::ODBCDescriptor;
516+ using ODBC::ODBCEnvironment;
517+ using ODBC::ODBCStatement;
518+
519+ if (!handle) {
520+ return SQL_INVALID_HANDLE;
521+ }
522+
523+ // Record number must be greater or equal to 1
524+ if (rec_number < 1 || buffer_length < 0 ) {
525+ return SQL_ERROR;
526+ }
527+
528+ // Set character type to be Unicode by default
529+ const bool is_unicode = true ;
530+ Diagnostics* diagnostics = nullptr ;
531+
532+ switch (handle_type) {
533+ case SQL_HANDLE_ENV: {
534+ auto * environment = ODBCEnvironment::Of (handle);
535+ diagnostics = &environment->GetDiagnostics ();
536+ break ;
537+ }
538+
539+ case SQL_HANDLE_DBC: {
540+ auto * connection = ODBCConnection::Of (handle);
541+ diagnostics = &connection->GetDiagnostics ();
542+ break ;
543+ }
544+
545+ case SQL_HANDLE_DESC: {
546+ auto * descriptor = ODBCDescriptor::Of (handle);
547+ diagnostics = &descriptor->GetDiagnostics ();
548+ break ;
549+ }
550+
551+ case SQL_HANDLE_STMT: {
552+ auto * statement = ODBCStatement::Of (handle);
553+ diagnostics = &statement->GetDiagnostics ();
554+ break ;
555+ }
556+
557+ default :
558+ return SQL_INVALID_HANDLE;
559+ }
560+
561+ if (!diagnostics) {
562+ return SQL_ERROR;
563+ }
564+
565+ // Convert from ODBC 1 based record number to internal diagnostics 0 indexed storage
566+ const size_t record_index = static_cast <size_t >(rec_number - 1 );
567+ if (!diagnostics->HasRecord (record_index)) {
568+ return SQL_NO_DATA;
569+ }
570+
571+ if (sql_state) {
572+ // The length of the sql state is always 5 characters plus null
573+ SQLSMALLINT size = 6 ;
574+ const std::string& state = diagnostics->GetSQLState (record_index);
575+
576+ // Microsoft documentation does not mention
577+ // any SQLGetDiagRec return value that is associated with `sql_state` buffer, so
578+ // the return value for writing `sql_state` buffer is ignored by the driver.
579+ ARROW_UNUSED (GetStringAttribute (is_unicode, state, false , sql_state, size, &size,
580+ *diagnostics));
581+ }
582+
583+ if (native_error_ptr) {
584+ *native_error_ptr = diagnostics->GetNativeError (record_index);
585+ }
586+
587+ if (message_text || text_length_ptr) {
588+ const std::string& message = diagnostics->GetMessageText (record_index);
589+
590+ // According to Microsoft documentation,
591+ // SQL_SUCCESS_WITH_INFO should be returned if `*message_text` buffer was too
592+ // small to hold the requested diagnostic message.
593+ return GetStringAttribute (is_unicode, message, false , message_text, buffer_length,
594+ text_length_ptr, *diagnostics);
595+ }
596+
597+ return SQL_SUCCESS;
249598}
250599
251600SQLRETURN SQLGetEnvAttr (SQLHENV env, SQLINTEGER attr, SQLPOINTER value_ptr,
0 commit comments