Skip to content

Commit d92a1a8

Browse files
alinaliBQjusting-bq
andcommitted
Extract SQLColAttribute implementation
Co-Authored-By: alinalibq <[email protected]> Co-Authored-By: justing-bq <[email protected]>
1 parent b10386e commit d92a1a8

File tree

7 files changed

+1396
-52
lines changed

7 files changed

+1396
-52
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,8 +1210,90 @@ SQLRETURN SQLColAttribute(SQLHSTMT stmt, SQLUSMALLINT record_number,
12101210
<< ", output_length: " << static_cast<const void*>(output_length)
12111211
<< ", numeric_attribute_ptr: "
12121212
<< static_cast<const void*>(numeric_attribute_ptr);
1213-
// GH-47721 TODO: Implement SQLColAttribute, pre-requisite requires SQLColumns
1214-
return SQL_INVALID_HANDLE;
1213+
1214+
using ODBC::ODBCDescriptor;
1215+
using ODBC::ODBCStatement;
1216+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1217+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1218+
ODBCDescriptor* ird = statement->GetIRD();
1219+
SQLINTEGER output_length_int;
1220+
switch (field_identifier) {
1221+
// Numeric attributes
1222+
// internal is SQLLEN, no conversion is needed
1223+
case SQL_DESC_DISPLAY_SIZE:
1224+
case SQL_DESC_OCTET_LENGTH: {
1225+
ird->GetField(record_number, field_identifier, numeric_attribute_ptr,
1226+
buffer_length, &output_length_int);
1227+
break;
1228+
}
1229+
// internal is SQLULEN, conversion is needed.
1230+
case SQL_COLUMN_LENGTH: // ODBC 2.0
1231+
case SQL_DESC_LENGTH: {
1232+
SQLULEN temp;
1233+
ird->GetField(record_number, field_identifier, &temp, buffer_length,
1234+
&output_length_int);
1235+
if (numeric_attribute_ptr) {
1236+
*numeric_attribute_ptr = static_cast<SQLLEN>(temp);
1237+
}
1238+
break;
1239+
}
1240+
// internal is SQLINTEGER, conversion is needed.
1241+
case SQL_DESC_AUTO_UNIQUE_VALUE:
1242+
case SQL_DESC_CASE_SENSITIVE:
1243+
case SQL_DESC_NUM_PREC_RADIX: {
1244+
SQLINTEGER temp;
1245+
ird->GetField(record_number, field_identifier, &temp, buffer_length,
1246+
&output_length_int);
1247+
if (numeric_attribute_ptr) {
1248+
*numeric_attribute_ptr = static_cast<SQLLEN>(temp);
1249+
}
1250+
break;
1251+
}
1252+
// internal is SQLSMALLINT, conversion is needed.
1253+
case SQL_DESC_CONCISE_TYPE:
1254+
case SQL_DESC_COUNT:
1255+
case SQL_DESC_FIXED_PREC_SCALE:
1256+
case SQL_DESC_TYPE:
1257+
case SQL_DESC_NULLABLE:
1258+
case SQL_COLUMN_PRECISION: // ODBC 2.0
1259+
case SQL_DESC_PRECISION:
1260+
case SQL_COLUMN_SCALE: // ODBC 2.0
1261+
case SQL_DESC_SCALE:
1262+
case SQL_DESC_SEARCHABLE:
1263+
case SQL_DESC_UNNAMED:
1264+
case SQL_DESC_UNSIGNED:
1265+
case SQL_DESC_UPDATABLE: {
1266+
SQLSMALLINT temp;
1267+
ird->GetField(record_number, field_identifier, &temp, buffer_length,
1268+
&output_length_int);
1269+
if (numeric_attribute_ptr) {
1270+
*numeric_attribute_ptr = static_cast<SQLLEN>(temp);
1271+
}
1272+
break;
1273+
}
1274+
// Character attributes
1275+
case SQL_DESC_BASE_COLUMN_NAME:
1276+
case SQL_DESC_BASE_TABLE_NAME:
1277+
case SQL_DESC_CATALOG_NAME:
1278+
case SQL_DESC_LABEL:
1279+
case SQL_DESC_LITERAL_PREFIX:
1280+
case SQL_DESC_LITERAL_SUFFIX:
1281+
case SQL_DESC_LOCAL_TYPE_NAME:
1282+
case SQL_DESC_NAME:
1283+
case SQL_DESC_SCHEMA_NAME:
1284+
case SQL_DESC_TABLE_NAME:
1285+
case SQL_DESC_TYPE_NAME:
1286+
ird->GetField(record_number, field_identifier, character_attribute_ptr,
1287+
buffer_length, &output_length_int);
1288+
break;
1289+
default:
1290+
throw DriverException("Invalid descriptor field", "HY091");
1291+
}
1292+
if (output_length) {
1293+
*output_length = static_cast<SQLSMALLINT>(output_length_int);
1294+
}
1295+
return SQL_SUCCESS;
1296+
});
12151297
}
12161298

12171299
SQLRETURN SQLGetTypeInfo(SQLHSTMT stmt, SQLSMALLINT data_type) {

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.cc

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include "arrow/flight/sql/column_metadata.h"
2121
#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
2222
#include "arrow/flight/sql/odbc/odbc_impl/util.h"
23+
#include "arrow/type_traits.h"
24+
#include "arrow/util/key_value_metadata.h"
2325

2426
#include <utility>
2527
#include "arrow/flight/sql/odbc/odbc_impl/exceptions.h"
@@ -40,12 +42,8 @@ constexpr int32_t DefaultDecimalPrecision = 38;
4042
constexpr int32_t DefaultLengthForVariableLengthColumns = 1024;
4143

4244
namespace {
43-
std::shared_ptr<const KeyValueMetadata> empty_metadata_map(new KeyValueMetadata);
44-
4545
inline ColumnMetadata GetMetadata(const std::shared_ptr<Field>& field) {
46-
const auto& metadata_map = field->metadata();
47-
48-
ColumnMetadata metadata(metadata_map ? metadata_map : empty_metadata_map);
46+
ColumnMetadata metadata(field->metadata());
4947
return metadata;
5048
}
5149

@@ -207,10 +205,14 @@ size_t FlightSqlResultSetMetadata::GetOctetLength(int column_position) {
207205
.value_or(DefaultLengthForVariableLengthColumns);
208206
}
209207

210-
std::string FlightSqlResultSetMetadata::GetTypeName(int column_position) {
208+
std::string FlightSqlResultSetMetadata::GetTypeName(int column_position,
209+
int16_t data_type) {
211210
ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1));
212211

213-
return metadata.GetTypeName().ValueOrElse([] { return ""; });
212+
return metadata.GetTypeName().ValueOrElse([data_type] {
213+
// If we get an empty type name, figure out the type name from the data_type.
214+
return util::GetTypeNameFromSqlDataType(data_type);
215+
});
214216
}
215217

216218
Updatability FlightSqlResultSetMetadata::GetUpdatable(int column_position) {
@@ -239,20 +241,14 @@ Searchability FlightSqlResultSetMetadata::IsSearchable(int column_position) {
239241

240242
bool FlightSqlResultSetMetadata::IsUnsigned(int column_position) {
241243
const std::shared_ptr<Field>& field = schema_->field(column_position - 1);
242-
243-
switch (field->type()->id()) {
244-
case Type::UINT8:
245-
case Type::UINT16:
246-
case Type::UINT32:
247-
case Type::UINT64:
248-
return true;
249-
default:
250-
return false;
251-
}
244+
arrow::Type::type type_id = field->type()->id();
245+
// non-decimal and non-numeric types are unsigned.
246+
return !arrow::is_decimal(type_id) &&
247+
(!arrow::is_numeric(type_id) || arrow::is_unsigned_integer(type_id));
252248
}
253249

254250
bool FlightSqlResultSetMetadata::IsFixedPrecScale(int column_position) {
255-
// TODO: Flight SQL column metadata does not have this, should we add to the spec?
251+
// Precision for Arrow data types are modifiable by the user
256252
return false;
257253
}
258254

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class FlightSqlResultSetMetadata : public ResultSetMetadata {
7777

7878
size_t GetOctetLength(int column_position) override;
7979

80-
std::string GetTypeName(int column_position) override;
80+
std::string GetTypeName(int column_position, int16_t data_type) override;
8181

8282
Updatability GetUpdatable(int column_position) override;
8383

@@ -87,6 +87,7 @@ class FlightSqlResultSetMetadata : public ResultSetMetadata {
8787

8888
Searchability IsSearchable(int column_position) override;
8989

90+
/// \brief Return true if the column is unsigned or not numeric
9091
bool IsUnsigned(int column_position) override;
9192

9293
bool IsFixedPrecScale(int column_position) override;

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.cc

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,9 @@ void ODBCDescriptor::GetHeaderField(SQLSMALLINT field_identifier, SQLPOINTER val
276276
GetAttribute(rows_processed_ptr_, value, buffer_length, output_length);
277277
break;
278278
case SQL_DESC_COUNT: {
279-
GetAttribute(highest_one_based_bound_record_, value, buffer_length, output_length);
279+
// highest_one_based_bound_record_ equals number of records + 1
280+
GetAttribute(static_cast<SQLSMALLINT>(highest_one_based_bound_record_ - 1), value,
281+
buffer_length, output_length);
280282
break;
281283
}
282284
default:
@@ -310,54 +312,55 @@ void ODBCDescriptor::GetField(SQLSMALLINT record_number, SQLSMALLINT field_ident
310312
throw DriverException("Invalid descriptor index", "07009");
311313
}
312314

313-
// TODO: Restrict fields based on AppDescriptor IPD, and IRD.
315+
// GH-47867 TODO: Restrict fields based on AppDescriptor IPD, and IRD.
314316

317+
bool length_in_bytes = true;
315318
SQLSMALLINT zero_based_record = record_number - 1;
316319
const DescriptorRecord& record = records_[zero_based_record];
317320
switch (field_identifier) {
318321
case SQL_DESC_BASE_COLUMN_NAME:
319-
GetAttributeUTF8(record.base_column_name, value, buffer_length, output_length,
320-
GetDiagnostics());
322+
GetAttributeSQLWCHAR(record.base_column_name, length_in_bytes, value, buffer_length,
323+
output_length, GetDiagnostics());
321324
break;
322325
case SQL_DESC_BASE_TABLE_NAME:
323-
GetAttributeUTF8(record.base_table_name, value, buffer_length, output_length,
324-
GetDiagnostics());
326+
GetAttributeSQLWCHAR(record.base_table_name, length_in_bytes, value, buffer_length,
327+
output_length, GetDiagnostics());
325328
break;
326329
case SQL_DESC_CATALOG_NAME:
327-
GetAttributeUTF8(record.catalog_name, value, buffer_length, output_length,
328-
GetDiagnostics());
330+
GetAttributeSQLWCHAR(record.catalog_name, length_in_bytes, value, buffer_length,
331+
output_length, GetDiagnostics());
329332
break;
330333
case SQL_DESC_LABEL:
331-
GetAttributeUTF8(record.label, value, buffer_length, output_length,
332-
GetDiagnostics());
334+
GetAttributeSQLWCHAR(record.label, length_in_bytes, value, buffer_length,
335+
output_length, GetDiagnostics());
333336
break;
334337
case SQL_DESC_LITERAL_PREFIX:
335-
GetAttributeUTF8(record.literal_prefix, value, buffer_length, output_length,
336-
GetDiagnostics());
338+
GetAttributeSQLWCHAR(record.literal_prefix, length_in_bytes, value, buffer_length,
339+
output_length, GetDiagnostics());
337340
break;
338341
case SQL_DESC_LITERAL_SUFFIX:
339-
GetAttributeUTF8(record.literal_suffix, value, buffer_length, output_length,
340-
GetDiagnostics());
342+
GetAttributeSQLWCHAR(record.literal_suffix, length_in_bytes, value, buffer_length,
343+
output_length, GetDiagnostics());
341344
break;
342345
case SQL_DESC_LOCAL_TYPE_NAME:
343-
GetAttributeUTF8(record.local_type_name, value, buffer_length, output_length,
344-
GetDiagnostics());
346+
GetAttributeSQLWCHAR(record.local_type_name, length_in_bytes, value, buffer_length,
347+
output_length, GetDiagnostics());
345348
break;
346349
case SQL_DESC_NAME:
347-
GetAttributeUTF8(record.name, value, buffer_length, output_length,
348-
GetDiagnostics());
350+
GetAttributeSQLWCHAR(record.name, length_in_bytes, value, buffer_length,
351+
output_length, GetDiagnostics());
349352
break;
350353
case SQL_DESC_SCHEMA_NAME:
351-
GetAttributeUTF8(record.schema_name, value, buffer_length, output_length,
352-
GetDiagnostics());
354+
GetAttributeSQLWCHAR(record.schema_name, length_in_bytes, value, buffer_length,
355+
output_length, GetDiagnostics());
353356
break;
354357
case SQL_DESC_TABLE_NAME:
355-
GetAttributeUTF8(record.table_name, value, buffer_length, output_length,
356-
GetDiagnostics());
358+
GetAttributeSQLWCHAR(record.table_name, length_in_bytes, value, buffer_length,
359+
output_length, GetDiagnostics());
357360
break;
358361
case SQL_DESC_TYPE_NAME:
359-
GetAttributeUTF8(record.type_name, value, buffer_length, output_length,
360-
GetDiagnostics());
362+
GetAttributeSQLWCHAR(record.type_name, length_in_bytes, value, buffer_length,
363+
output_length, GetDiagnostics());
361364
break;
362365

363366
case SQL_DESC_DATA_PTR:
@@ -367,7 +370,7 @@ void ODBCDescriptor::GetField(SQLSMALLINT record_number, SQLSMALLINT field_ident
367370
case SQL_DESC_OCTET_LENGTH_PTR:
368371
GetAttribute(record.indicator_ptr, value, buffer_length, output_length);
369372
break;
370-
373+
case SQL_COLUMN_LENGTH: // ODBC 2.0
371374
case SQL_DESC_LENGTH:
372375
GetAttribute(record.length, value, buffer_length, output_length);
373376
break;
@@ -407,12 +410,14 @@ void ODBCDescriptor::GetField(SQLSMALLINT record_number, SQLSMALLINT field_ident
407410
case SQL_DESC_PARAMETER_TYPE:
408411
GetAttribute(record.param_type, value, buffer_length, output_length);
409412
break;
413+
case SQL_COLUMN_PRECISION: // ODBC 2.0
410414
case SQL_DESC_PRECISION:
411415
GetAttribute(record.precision, value, buffer_length, output_length);
412416
break;
413417
case SQL_DESC_ROWVER:
414418
GetAttribute(record.row_ver, value, buffer_length, output_length);
415419
break;
420+
case SQL_COLUMN_SCALE: // ODBC 2.0
416421
case SQL_DESC_SCALE:
417422
GetAttribute(record.scale, value, buffer_length, output_length);
418423
break;
@@ -479,6 +484,8 @@ void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) {
479484

480485
for (size_t i = 0; i < records_.size(); ++i) {
481486
size_t one_based_index = i + 1;
487+
int16_t concise_type = rsmd->GetConciseType(one_based_index);
488+
482489
records_[i].base_column_name = rsmd->GetBaseColumnName(one_based_index);
483490
records_[i].base_table_name = rsmd->GetBaseTableName(one_based_index);
484491
records_[i].catalog_name = rsmd->GetCatalogName(one_based_index);
@@ -489,9 +496,8 @@ void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) {
489496
records_[i].name = rsmd->GetName(one_based_index);
490497
records_[i].schema_name = rsmd->GetSchemaName(one_based_index);
491498
records_[i].table_name = rsmd->GetTableName(one_based_index);
492-
records_[i].type_name = rsmd->GetTypeName(one_based_index);
493-
records_[i].concise_type = GetSqlTypeForODBCVersion(
494-
rsmd->GetConciseType(one_based_index), is_2x_connection_);
499+
records_[i].type_name = rsmd->GetTypeName(one_based_index, concise_type);
500+
records_[i].concise_type = GetSqlTypeForODBCVersion(concise_type, is_2x_connection_);
495501
records_[i].data_ptr = nullptr;
496502
records_[i].indicator_ptr = nullptr;
497503
records_[i].display_size = rsmd->GetColumnDisplaySize(one_based_index);

cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set_metadata.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
#pragma once
1919

20-
#include "arrow/flight/sql/odbc/odbc_impl/types.h"
21-
2220
#include <string>
21+
#include "arrow/flight/sql/odbc/odbc_impl/types.h"
2322

2423
namespace arrow::flight::sql::odbc {
2524

@@ -143,8 +142,9 @@ class ResultSetMetadata {
143142

144143
/// \brief It returns the data type as a string.
145144
/// \param column_position [in] the position of the column, starting from 1.
145+
/// \param data_type [in] the data type of the column.
146146
/// \return the data type string.
147-
virtual std::string GetTypeName(int column_position) = 0;
147+
virtual std::string GetTypeName(int column_position, int16_t data_type) = 0;
148148

149149
/// \brief It returns a numeric values indicate the updatability of the
150150
/// column.

cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_arrow_test(flight_sql_odbc_test
3434
SOURCES
3535
odbc_test_suite.cc
3636
odbc_test_suite.h
37+
columns_test.cc
3738
connection_attr_test.cc
3839
connection_test.cc
3940
statement_attr_test.cc

0 commit comments

Comments
 (0)