@@ -65,6 +65,20 @@ SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) {
6565 return SQL_INVALID_HANDLE;
6666}
6767
68+ inline bool IsValidStringFieldArgs (SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
69+ SQLSMALLINT* string_length_ptr, bool is_unicode) {
70+ const SQLSMALLINT char_size = is_unicode ? GetSqlWCharSize () : sizeof (char );
71+ const bool has_valid_buffer =
72+ diag_info_ptr && buffer_length >= 0 && buffer_length % char_size == 0 ;
73+
74+ // regardless of capacity return false if invalid
75+ if (diag_info_ptr && !has_valid_buffer) {
76+ return false ;
77+ }
78+
79+ return has_valid_buffer || string_length_ptr;
80+ }
81+
6882SQLRETURN SQLGetDiagField (SQLSMALLINT handle_type, SQLHANDLE handle,
6983 SQLSMALLINT rec_number, SQLSMALLINT diag_identifier,
7084 SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
@@ -76,8 +90,259 @@ SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle,
7690 << " , diag_info_ptr: " << diag_info_ptr
7791 << " , buffer_length: " << buffer_length << " , string_length_ptr: "
7892 << static_cast <const void *>(string_length_ptr);
79- // GH-46575 TODO: Implement SQLGetDiagField
80- return SQL_INVALID_HANDLE;
93+ // GH-46575 TODO: Add tests for SQLGetDiagField
94+ using arrow::flight::sql::odbc::Diagnostics;
95+ using ODBC::GetStringAttribute;
96+ using ODBC::ODBCConnection;
97+ using ODBC::ODBCDescriptor;
98+ using ODBC::ODBCEnvironment;
99+ using ODBC::ODBCStatement;
100+
101+ if (!handle) {
102+ return SQL_INVALID_HANDLE;
103+ }
104+
105+ if (!diag_info_ptr && !string_length_ptr) {
106+ return SQL_ERROR;
107+ }
108+
109+ // If buffer length derived from null terminated string
110+ if (diag_info_ptr && buffer_length == SQL_NTS) {
111+ const wchar_t * str = reinterpret_cast <wchar_t *>(diag_info_ptr);
112+ buffer_length = wcslen (str) * arrow::flight::sql::odbc::GetSqlWCharSize ();
113+ }
114+
115+ // Set character type to be Unicode by default
116+ const bool is_unicode = true ;
117+ Diagnostics* diagnostics = nullptr ;
118+
119+ switch (handle_type) {
120+ case SQL_HANDLE_ENV: {
121+ ODBCEnvironment* environment = reinterpret_cast <ODBCEnvironment*>(handle);
122+ diagnostics = &environment->GetDiagnostics ();
123+ break ;
124+ }
125+
126+ case SQL_HANDLE_DBC: {
127+ ODBCConnection* connection = reinterpret_cast <ODBCConnection*>(handle);
128+ diagnostics = &connection->GetDiagnostics ();
129+ break ;
130+ }
131+
132+ case SQL_HANDLE_DESC: {
133+ ODBCDescriptor* descriptor = reinterpret_cast <ODBCDescriptor*>(handle);
134+ diagnostics = &descriptor->GetDiagnostics ();
135+ break ;
136+ }
137+
138+ case SQL_HANDLE_STMT: {
139+ ODBCStatement* statement = reinterpret_cast <ODBCStatement*>(handle);
140+ diagnostics = &statement->GetDiagnostics ();
141+ break ;
142+ }
143+
144+ default :
145+ return SQL_ERROR;
146+ }
147+
148+ if (!diagnostics) {
149+ return SQL_ERROR;
150+ }
151+
152+ // Retrieve and return if header level diagnostics
153+ switch (diag_identifier) {
154+ case SQL_DIAG_NUMBER: {
155+ if (diag_info_ptr) {
156+ *static_cast <SQLINTEGER*>(diag_info_ptr) =
157+ static_cast <SQLINTEGER>(diagnostics->GetRecordCount ());
158+ }
159+
160+ if (string_length_ptr) {
161+ *string_length_ptr = sizeof (SQLINTEGER);
162+ }
163+
164+ return SQL_SUCCESS;
165+ }
166+
167+ // TODO implement return code function
168+ case SQL_DIAG_RETURNCODE: {
169+ return SQL_SUCCESS;
170+ }
171+
172+ case SQL_DIAG_CURSOR_ROW_COUNT: {
173+ if (handle_type == SQL_HANDLE_STMT) {
174+ if (diag_info_ptr) {
175+ // Will always be 0 if only SELECT supported
176+ *static_cast <SQLLEN*>(diag_info_ptr) = 0 ;
177+ }
178+
179+ if (string_length_ptr) {
180+ *string_length_ptr = sizeof (SQLLEN);
181+ }
182+
183+ return SQL_SUCCESS;
184+ }
185+
186+ return SQL_ERROR;
187+ }
188+
189+ // Not supported
190+ case SQL_DIAG_DYNAMIC_FUNCTION:
191+ case SQL_DIAG_DYNAMIC_FUNCTION_CODE: {
192+ if (handle_type == SQL_HANDLE_STMT) {
193+ return SQL_SUCCESS;
194+ }
195+
196+ return SQL_ERROR;
197+ }
198+
199+ case SQL_DIAG_ROW_COUNT: {
200+ if (handle_type == SQL_HANDLE_STMT) {
201+ if (diag_info_ptr) {
202+ // Will always be 0 if only SELECT is supported
203+ *static_cast <SQLLEN*>(diag_info_ptr) = 0 ;
204+ }
205+
206+ if (string_length_ptr) {
207+ *string_length_ptr = sizeof (SQLLEN);
208+ }
209+
210+ return SQL_SUCCESS;
211+ }
212+
213+ return SQL_ERROR;
214+ }
215+ }
216+
217+ // If not a diagnostic header field then the record number must be 1 or greater
218+ if (rec_number < 1 ) {
219+ return SQL_ERROR;
220+ }
221+
222+ // Retrieve record level diagnostics from specified 1 based record
223+ const uint32_t record_index = static_cast <uint32_t >(rec_number - 1 );
224+ if (!diagnostics->HasRecord (record_index)) {
225+ return SQL_NO_DATA;
226+ }
227+
228+ // Retrieve record field data
229+ switch (diag_identifier) {
230+ case SQL_DIAG_MESSAGE_TEXT: {
231+ if (IsValidStringFieldArgs (diag_info_ptr, buffer_length, string_length_ptr,
232+ is_unicode)) {
233+ const std::string& message = diagnostics->GetMessageText (record_index);
234+ return GetStringAttribute (is_unicode, message, true , diag_info_ptr, buffer_length,
235+ string_length_ptr, *diagnostics);
236+ }
237+
238+ return SQL_ERROR;
239+ }
240+
241+ case SQL_DIAG_NATIVE: {
242+ if (diag_info_ptr) {
243+ *static_cast <SQLINTEGER*>(diag_info_ptr) =
244+ diagnostics->GetNativeError (record_index);
245+ }
246+
247+ if (string_length_ptr) {
248+ *string_length_ptr = sizeof (SQLINTEGER);
249+ }
250+
251+ return SQL_SUCCESS;
252+ }
253+
254+ case SQL_DIAG_SERVER_NAME: {
255+ if (IsValidStringFieldArgs (diag_info_ptr, buffer_length, string_length_ptr,
256+ is_unicode)) {
257+ switch (handle_type) {
258+ case SQL_HANDLE_DBC: {
259+ ODBCConnection* connection = reinterpret_cast <ODBCConnection*>(handle);
260+ std::string dsn = connection->GetDSN ();
261+ return GetStringAttribute (is_unicode, dsn, true , diag_info_ptr, buffer_length,
262+ string_length_ptr, *diagnostics);
263+ }
264+
265+ case SQL_HANDLE_DESC: {
266+ ODBCDescriptor* descriptor = reinterpret_cast <ODBCDescriptor*>(handle);
267+ ODBCConnection* connection = &descriptor->GetConnection ();
268+ std::string dsn = connection->GetDSN ();
269+ return GetStringAttribute (is_unicode, dsn, true , diag_info_ptr, buffer_length,
270+ string_length_ptr, *diagnostics);
271+ break ;
272+ }
273+
274+ case SQL_HANDLE_STMT: {
275+ ODBCStatement* statement = reinterpret_cast <ODBCStatement*>(handle);
276+ ODBCConnection* connection = &statement->GetConnection ();
277+ std::string dsn = connection->GetDSN ();
278+ return GetStringAttribute (is_unicode, dsn, true , diag_info_ptr, buffer_length,
279+ string_length_ptr, *diagnostics);
280+ }
281+
282+ default :
283+ return SQL_ERROR;
284+ }
285+ }
286+
287+ return SQL_ERROR;
288+ }
289+
290+ case SQL_DIAG_SQLSTATE: {
291+ if (IsValidStringFieldArgs (diag_info_ptr, buffer_length, string_length_ptr,
292+ is_unicode)) {
293+ const std::string& state = diagnostics->GetSQLState (record_index);
294+ return GetStringAttribute (is_unicode, state, true , diag_info_ptr, buffer_length,
295+ string_length_ptr, *diagnostics);
296+ }
297+
298+ return SQL_ERROR;
299+ }
300+
301+ // Return valid dummy variable for unimplemented field
302+ case SQL_DIAG_COLUMN_NUMBER: {
303+ if (diag_info_ptr) {
304+ *static_cast <SQLINTEGER*>(diag_info_ptr) = SQL_NO_COLUMN_NUMBER;
305+ }
306+
307+ if (string_length_ptr) {
308+ *string_length_ptr = sizeof (SQLINTEGER);
309+ }
310+
311+ return SQL_SUCCESS;
312+ }
313+
314+ // Return empty string dummy variable for unimplemented fields
315+ case SQL_DIAG_CLASS_ORIGIN:
316+ case SQL_DIAG_CONNECTION_NAME:
317+ case SQL_DIAG_SUBCLASS_ORIGIN: {
318+ if (IsValidStringFieldArgs (diag_info_ptr, buffer_length, string_length_ptr,
319+ is_unicode)) {
320+ return GetStringAttribute (is_unicode, " " , true , diag_info_ptr, buffer_length,
321+ string_length_ptr, *diagnostics);
322+ }
323+
324+ return SQL_ERROR;
325+ }
326+
327+ // Return valid dummy variable for unimplemented field
328+ case SQL_DIAG_ROW_NUMBER: {
329+ if (diag_info_ptr) {
330+ *static_cast <SQLLEN*>(diag_info_ptr) = SQL_NO_ROW_NUMBER;
331+ }
332+
333+ if (string_length_ptr) {
334+ *string_length_ptr = sizeof (SQLLEN);
335+ }
336+
337+ return SQL_SUCCESS;
338+ }
339+
340+ default : {
341+ return SQL_ERROR;
342+ }
343+ }
344+
345+ return SQL_ERROR;
81346}
82347
83348SQLRETURN SQLGetDiagRec (SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT rec_number,
@@ -91,8 +356,84 @@ SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT r
91356 << " , message_text: " << static_cast <const void *>(message_text)
92357 << " , buffer_length: " << buffer_length
93358 << " , text_length_ptr: " << static_cast <const void *>(text_length_ptr);
94- // GH-46575 TODO: Implement SQLGetDiagRec
95- return SQL_INVALID_HANDLE;
359+ // GH-46575 TODO: Add tests for SQLGetDiagRec
360+ using arrow::flight::sql::odbc::Diagnostics;
361+ using ODBC::GetStringAttribute;
362+ using ODBC::ODBCConnection;
363+ using ODBC::ODBCDescriptor;
364+ using ODBC::ODBCEnvironment;
365+ using ODBC::ODBCStatement;
366+
367+ if (!handle) {
368+ return SQL_INVALID_HANDLE;
369+ }
370+
371+ // Record number must be greater or equal to 1
372+ if (rec_number < 1 || buffer_length < 0 ) {
373+ return SQL_ERROR;
374+ }
375+
376+ // Set character type to be Unicode by default
377+ const bool is_unicode = true ;
378+ Diagnostics* diagnostics = nullptr ;
379+
380+ switch (handle_type) {
381+ case SQL_HANDLE_ENV: {
382+ auto * environment = ODBCEnvironment::Of (handle);
383+ diagnostics = &environment->GetDiagnostics ();
384+ break ;
385+ }
386+
387+ case SQL_HANDLE_DBC: {
388+ auto * connection = ODBCConnection::Of (handle);
389+ diagnostics = &connection->GetDiagnostics ();
390+ break ;
391+ }
392+
393+ case SQL_HANDLE_DESC: {
394+ auto * descriptor = ODBCDescriptor::Of (handle);
395+ diagnostics = &descriptor->GetDiagnostics ();
396+ break ;
397+ }
398+
399+ case SQL_HANDLE_STMT: {
400+ auto * statement = ODBCStatement::Of (handle);
401+ diagnostics = &statement->GetDiagnostics ();
402+ break ;
403+ }
404+
405+ default :
406+ return SQL_INVALID_HANDLE;
407+ }
408+
409+ if (!diagnostics) {
410+ return SQL_ERROR;
411+ }
412+
413+ // Convert from ODBC 1 based record number to internal diagnostics 0 indexed storage
414+ const size_t record_index = static_cast <size_t >(rec_number - 1 );
415+ if (!diagnostics->HasRecord (record_index)) {
416+ return SQL_NO_DATA;
417+ }
418+
419+ if (sql_state) {
420+ // The length of the sql state is always 5 characters plus null
421+ SQLSMALLINT size = 6 ;
422+ const std::string& state = diagnostics->GetSQLState (record_index);
423+ GetStringAttribute (is_unicode, state, false , sql_state, size, &size, *diagnostics);
424+ }
425+
426+ if (native_error_ptr) {
427+ *native_error_ptr = diagnostics->GetNativeError (record_index);
428+ }
429+
430+ if (message_text || text_length_ptr) {
431+ const std::string& message = diagnostics->GetMessageText (record_index);
432+ return GetStringAttribute (is_unicode, message, false , message_text, buffer_length,
433+ text_length_ptr, *diagnostics);
434+ }
435+
436+ return SQL_SUCCESS;
96437}
97438
98439SQLRETURN SQLGetEnvAttr (SQLHENV env, SQLINTEGER attr, SQLPOINTER value_ptr,
0 commit comments