Skip to content

Commit 0c8fddc

Browse files
alinaliBQrscales
andcommitted
Extract SQLGetDiagField and SQLGetDiagRec implementation
Co-Authored-By: rscales <[email protected]>
1 parent ac9d721 commit 0c8fddc

File tree

1 file changed

+345
-4
lines changed

1 file changed

+345
-4
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 345 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) {
6565
return SQL_INVALID_HANDLE;
6666
}
6767

68+
inline bool IsValidStringFieldArgs(SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
69+
SQLSMALLINT* string_length_ptr, bool is_unicode) {
70+
const SQLSMALLINT char_size = is_unicode ? GetSqlWCharSize() : sizeof(char);
71+
const bool has_valid_buffer =
72+
diag_info_ptr && buffer_length >= 0 && buffer_length % char_size == 0;
73+
74+
// regardless of capacity return false if invalid
75+
if (diag_info_ptr && !has_valid_buffer) {
76+
return false;
77+
}
78+
79+
return has_valid_buffer || string_length_ptr;
80+
}
81+
6882
SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle,
6983
SQLSMALLINT rec_number, SQLSMALLINT diag_identifier,
7084
SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
@@ -76,8 +90,259 @@ SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle,
7690
<< ", diag_info_ptr: " << diag_info_ptr
7791
<< ", buffer_length: " << buffer_length << ", string_length_ptr: "
7892
<< static_cast<const void*>(string_length_ptr);
79-
// GH-46575 TODO: Implement SQLGetDiagField
80-
return SQL_INVALID_HANDLE;
93+
// GH-46575 TODO: Add tests for SQLGetDiagField
94+
using arrow::flight::sql::odbc::Diagnostics;
95+
using ODBC::GetStringAttribute;
96+
using ODBC::ODBCConnection;
97+
using ODBC::ODBCDescriptor;
98+
using ODBC::ODBCEnvironment;
99+
using ODBC::ODBCStatement;
100+
101+
if (!handle) {
102+
return SQL_INVALID_HANDLE;
103+
}
104+
105+
if (!diag_info_ptr && !string_length_ptr) {
106+
return SQL_ERROR;
107+
}
108+
109+
// If buffer length derived from null terminated string
110+
if (diag_info_ptr && buffer_length == SQL_NTS) {
111+
const wchar_t* str = reinterpret_cast<wchar_t*>(diag_info_ptr);
112+
buffer_length = wcslen(str) * arrow::flight::sql::odbc::GetSqlWCharSize();
113+
}
114+
115+
// Set character type to be Unicode by default
116+
const bool is_unicode = true;
117+
Diagnostics* diagnostics = nullptr;
118+
119+
switch (handle_type) {
120+
case SQL_HANDLE_ENV: {
121+
ODBCEnvironment* environment = reinterpret_cast<ODBCEnvironment*>(handle);
122+
diagnostics = &environment->GetDiagnostics();
123+
break;
124+
}
125+
126+
case SQL_HANDLE_DBC: {
127+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(handle);
128+
diagnostics = &connection->GetDiagnostics();
129+
break;
130+
}
131+
132+
case SQL_HANDLE_DESC: {
133+
ODBCDescriptor* descriptor = reinterpret_cast<ODBCDescriptor*>(handle);
134+
diagnostics = &descriptor->GetDiagnostics();
135+
break;
136+
}
137+
138+
case SQL_HANDLE_STMT: {
139+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(handle);
140+
diagnostics = &statement->GetDiagnostics();
141+
break;
142+
}
143+
144+
default:
145+
return SQL_ERROR;
146+
}
147+
148+
if (!diagnostics) {
149+
return SQL_ERROR;
150+
}
151+
152+
// Retrieve and return if header level diagnostics
153+
switch (diag_identifier) {
154+
case SQL_DIAG_NUMBER: {
155+
if (diag_info_ptr) {
156+
*static_cast<SQLINTEGER*>(diag_info_ptr) =
157+
static_cast<SQLINTEGER>(diagnostics->GetRecordCount());
158+
}
159+
160+
if (string_length_ptr) {
161+
*string_length_ptr = sizeof(SQLINTEGER);
162+
}
163+
164+
return SQL_SUCCESS;
165+
}
166+
167+
// TODO implement return code function
168+
case SQL_DIAG_RETURNCODE: {
169+
return SQL_SUCCESS;
170+
}
171+
172+
case SQL_DIAG_CURSOR_ROW_COUNT: {
173+
if (handle_type == SQL_HANDLE_STMT) {
174+
if (diag_info_ptr) {
175+
// Will always be 0 if only SELECT supported
176+
*static_cast<SQLLEN*>(diag_info_ptr) = 0;
177+
}
178+
179+
if (string_length_ptr) {
180+
*string_length_ptr = sizeof(SQLLEN);
181+
}
182+
183+
return SQL_SUCCESS;
184+
}
185+
186+
return SQL_ERROR;
187+
}
188+
189+
// Not supported
190+
case SQL_DIAG_DYNAMIC_FUNCTION:
191+
case SQL_DIAG_DYNAMIC_FUNCTION_CODE: {
192+
if (handle_type == SQL_HANDLE_STMT) {
193+
return SQL_SUCCESS;
194+
}
195+
196+
return SQL_ERROR;
197+
}
198+
199+
case SQL_DIAG_ROW_COUNT: {
200+
if (handle_type == SQL_HANDLE_STMT) {
201+
if (diag_info_ptr) {
202+
// Will always be 0 if only SELECT is supported
203+
*static_cast<SQLLEN*>(diag_info_ptr) = 0;
204+
}
205+
206+
if (string_length_ptr) {
207+
*string_length_ptr = sizeof(SQLLEN);
208+
}
209+
210+
return SQL_SUCCESS;
211+
}
212+
213+
return SQL_ERROR;
214+
}
215+
}
216+
217+
// If not a diagnostic header field then the record number must be 1 or greater
218+
if (rec_number < 1) {
219+
return SQL_ERROR;
220+
}
221+
222+
// Retrieve record level diagnostics from specified 1 based record
223+
const uint32_t record_index = static_cast<uint32_t>(rec_number - 1);
224+
if (!diagnostics->HasRecord(record_index)) {
225+
return SQL_NO_DATA;
226+
}
227+
228+
// Retrieve record field data
229+
switch (diag_identifier) {
230+
case SQL_DIAG_MESSAGE_TEXT: {
231+
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
232+
is_unicode)) {
233+
const std::string& message = diagnostics->GetMessageText(record_index);
234+
return GetStringAttribute(is_unicode, message, true, diag_info_ptr, buffer_length,
235+
string_length_ptr, *diagnostics);
236+
}
237+
238+
return SQL_ERROR;
239+
}
240+
241+
case SQL_DIAG_NATIVE: {
242+
if (diag_info_ptr) {
243+
*static_cast<SQLINTEGER*>(diag_info_ptr) =
244+
diagnostics->GetNativeError(record_index);
245+
}
246+
247+
if (string_length_ptr) {
248+
*string_length_ptr = sizeof(SQLINTEGER);
249+
}
250+
251+
return SQL_SUCCESS;
252+
}
253+
254+
case SQL_DIAG_SERVER_NAME: {
255+
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
256+
is_unicode)) {
257+
switch (handle_type) {
258+
case SQL_HANDLE_DBC: {
259+
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(handle);
260+
std::string dsn = connection->GetDSN();
261+
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length,
262+
string_length_ptr, *diagnostics);
263+
}
264+
265+
case SQL_HANDLE_DESC: {
266+
ODBCDescriptor* descriptor = reinterpret_cast<ODBCDescriptor*>(handle);
267+
ODBCConnection* connection = &descriptor->GetConnection();
268+
std::string dsn = connection->GetDSN();
269+
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length,
270+
string_length_ptr, *diagnostics);
271+
break;
272+
}
273+
274+
case SQL_HANDLE_STMT: {
275+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(handle);
276+
ODBCConnection* connection = &statement->GetConnection();
277+
std::string dsn = connection->GetDSN();
278+
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length,
279+
string_length_ptr, *diagnostics);
280+
}
281+
282+
default:
283+
return SQL_ERROR;
284+
}
285+
}
286+
287+
return SQL_ERROR;
288+
}
289+
290+
case SQL_DIAG_SQLSTATE: {
291+
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
292+
is_unicode)) {
293+
const std::string& state = diagnostics->GetSQLState(record_index);
294+
return GetStringAttribute(is_unicode, state, true, diag_info_ptr, buffer_length,
295+
string_length_ptr, *diagnostics);
296+
}
297+
298+
return SQL_ERROR;
299+
}
300+
301+
// Return valid dummy variable for unimplemented field
302+
case SQL_DIAG_COLUMN_NUMBER: {
303+
if (diag_info_ptr) {
304+
*static_cast<SQLINTEGER*>(diag_info_ptr) = SQL_NO_COLUMN_NUMBER;
305+
}
306+
307+
if (string_length_ptr) {
308+
*string_length_ptr = sizeof(SQLINTEGER);
309+
}
310+
311+
return SQL_SUCCESS;
312+
}
313+
314+
// Return empty string dummy variable for unimplemented fields
315+
case SQL_DIAG_CLASS_ORIGIN:
316+
case SQL_DIAG_CONNECTION_NAME:
317+
case SQL_DIAG_SUBCLASS_ORIGIN: {
318+
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
319+
is_unicode)) {
320+
return GetStringAttribute(is_unicode, "", true, diag_info_ptr, buffer_length,
321+
string_length_ptr, *diagnostics);
322+
}
323+
324+
return SQL_ERROR;
325+
}
326+
327+
// Return valid dummy variable for unimplemented field
328+
case SQL_DIAG_ROW_NUMBER: {
329+
if (diag_info_ptr) {
330+
*static_cast<SQLLEN*>(diag_info_ptr) = SQL_NO_ROW_NUMBER;
331+
}
332+
333+
if (string_length_ptr) {
334+
*string_length_ptr = sizeof(SQLLEN);
335+
}
336+
337+
return SQL_SUCCESS;
338+
}
339+
340+
default: {
341+
return SQL_ERROR;
342+
}
343+
}
344+
345+
return SQL_ERROR;
81346
}
82347

83348
SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT rec_number,
@@ -91,8 +356,84 @@ SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT r
91356
<< ", message_text: " << static_cast<const void*>(message_text)
92357
<< ", buffer_length: " << buffer_length
93358
<< ", text_length_ptr: " << static_cast<const void*>(text_length_ptr);
94-
// GH-46575 TODO: Implement SQLGetDiagRec
95-
return SQL_INVALID_HANDLE;
359+
// GH-46575 TODO: Add tests for SQLGetDiagRec
360+
using arrow::flight::sql::odbc::Diagnostics;
361+
using ODBC::GetStringAttribute;
362+
using ODBC::ODBCConnection;
363+
using ODBC::ODBCDescriptor;
364+
using ODBC::ODBCEnvironment;
365+
using ODBC::ODBCStatement;
366+
367+
if (!handle) {
368+
return SQL_INVALID_HANDLE;
369+
}
370+
371+
// Record number must be greater or equal to 1
372+
if (rec_number < 1 || buffer_length < 0) {
373+
return SQL_ERROR;
374+
}
375+
376+
// Set character type to be Unicode by default
377+
const bool is_unicode = true;
378+
Diagnostics* diagnostics = nullptr;
379+
380+
switch (handle_type) {
381+
case SQL_HANDLE_ENV: {
382+
auto* environment = ODBCEnvironment::Of(handle);
383+
diagnostics = &environment->GetDiagnostics();
384+
break;
385+
}
386+
387+
case SQL_HANDLE_DBC: {
388+
auto* connection = ODBCConnection::Of(handle);
389+
diagnostics = &connection->GetDiagnostics();
390+
break;
391+
}
392+
393+
case SQL_HANDLE_DESC: {
394+
auto* descriptor = ODBCDescriptor::Of(handle);
395+
diagnostics = &descriptor->GetDiagnostics();
396+
break;
397+
}
398+
399+
case SQL_HANDLE_STMT: {
400+
auto* statement = ODBCStatement::Of(handle);
401+
diagnostics = &statement->GetDiagnostics();
402+
break;
403+
}
404+
405+
default:
406+
return SQL_INVALID_HANDLE;
407+
}
408+
409+
if (!diagnostics) {
410+
return SQL_ERROR;
411+
}
412+
413+
// Convert from ODBC 1 based record number to internal diagnostics 0 indexed storage
414+
const size_t record_index = static_cast<size_t>(rec_number - 1);
415+
if (!diagnostics->HasRecord(record_index)) {
416+
return SQL_NO_DATA;
417+
}
418+
419+
if (sql_state) {
420+
// The length of the sql state is always 5 characters plus null
421+
SQLSMALLINT size = 6;
422+
const std::string& state = diagnostics->GetSQLState(record_index);
423+
GetStringAttribute(is_unicode, state, false, sql_state, size, &size, *diagnostics);
424+
}
425+
426+
if (native_error_ptr) {
427+
*native_error_ptr = diagnostics->GetNativeError(record_index);
428+
}
429+
430+
if (message_text || text_length_ptr) {
431+
const std::string& message = diagnostics->GetMessageText(record_index);
432+
return GetStringAttribute(is_unicode, message, false, message_text, buffer_length,
433+
text_length_ptr, *diagnostics);
434+
}
435+
436+
return SQL_SUCCESS;
96437
}
97438

98439
SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER value_ptr,

0 commit comments

Comments
 (0)