Skip to content

Commit d4230d9

Browse files
committed
feat: add adbc_schema
1 parent 693da06 commit d4230d9

File tree

5 files changed

+366
-0
lines changed

5 files changed

+366
-0
lines changed

CLAUDE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ The extension provides the following functions:
2727
- `adbc_tables(handle)` - Returns list of tables in the database.
2828
- `adbc_table_types(handle)` - Returns supported table types (e.g., "table", "view").
2929
- `adbc_columns(handle, [table_name := ...])` - Returns column metadata (name, type, ordinal position, nullability).
30+
- `adbc_schema(handle, table_name)` - Returns the Arrow schema for a specific table (field names, Arrow types, nullability).
3031

3132
### Example Usage
3233

@@ -49,6 +50,7 @@ SELECT * FROM adbc_info(getvariable('conn')::BIGINT);
4950
SELECT * FROM adbc_tables(getvariable('conn')::BIGINT);
5051
SELECT * FROM adbc_table_types(getvariable('conn')::BIGINT);
5152
SELECT * FROM adbc_columns(getvariable('conn')::BIGINT, table_name := 'test');
53+
SELECT * FROM adbc_schema(getvariable('conn')::BIGINT, 'test');
5254

5355
-- Disconnect
5456
SELECT adbc_disconnect(getvariable('conn')::BIGINT);

docs/README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,45 @@ Output:
317317
└──────────────┴─────────────┴────────────┴─────────────┴──────────────────┴─────────┴───────────┴─────────────┘
318318
```
319319

320+
### adbc_schema
321+
322+
Returns the Arrow schema for a specific table, showing field names, Arrow data types, and nullability.
323+
324+
```sql
325+
adbc_schema(connection_id, table_name, [catalog:=], [schema:=]) -> TABLE
326+
```
327+
328+
**Parameters:**
329+
- `connection_id`: Connection handle from `adbc_connect`
330+
- `table_name`: Name of the table to get the schema for
331+
- `catalog` (optional): Catalog containing the table
332+
- `schema` (optional): Database schema containing the table
333+
334+
**Returns:** A table with columns:
335+
- `field_name`: Name of the field/column
336+
- `field_type`: Arrow data type (e.g., "int64", "utf8", "float64", "timestamp[us]")
337+
- `nullable`: Whether the field allows NULL values
338+
339+
**Example:**
340+
341+
```sql
342+
SELECT * FROM adbc_schema(getvariable('conn')::BIGINT, 'users');
343+
```
344+
345+
Output:
346+
```
347+
┌────────────┬────────────┬──────────┐
348+
│ field_name │ field_type │ nullable │
349+
├────────────┼────────────┼──────────┤
350+
│ id │ int64 │ true │
351+
│ name │ utf8 │ true │
352+
│ email │ utf8 │ true │
353+
│ created_at │ timestamp │ true │
354+
└────────────┴────────────┴──────────┘
355+
```
356+
357+
**Note:** The `field_type` shows Arrow types, which may differ from the SQL types defined in the table. The mapping depends on the ADBC driver implementation.
358+
320359
## ADBC Drivers
321360

322361
ADBC drivers are available for many databases. Here are some common ones:

src/adbc_catalog.cpp

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,285 @@ static void AdbcColumnsFunction(ClientContext &context, TableFunctionInput &data
945945
output.SetCardinality(count);
946946
}
947947

948+
//===--------------------------------------------------------------------===//
949+
// adbc_schema - Get Arrow schema for a specific table
950+
//===--------------------------------------------------------------------===//
951+
952+
// Structure to hold a schema field row
953+
struct SchemaFieldRow {
954+
string field_name;
955+
string field_type;
956+
bool nullable;
957+
};
958+
959+
struct AdbcSchemaBindData : public TableFunctionData {
960+
int64_t connection_id;
961+
shared_ptr<AdbcConnectionWrapper> connection;
962+
string table_name;
963+
string catalog_filter;
964+
string schema_filter;
965+
bool has_catalog_filter = false;
966+
bool has_schema_filter = false;
967+
};
968+
969+
struct AdbcSchemaGlobalState : public GlobalTableFunctionState {
970+
// Extracted schema fields
971+
vector<SchemaFieldRow> field_rows;
972+
idx_t current_row = 0;
973+
974+
idx_t MaxThreads() const override {
975+
return 1;
976+
}
977+
};
978+
979+
// Helper to convert Arrow format string to human-readable type name
980+
static string ArrowFormatToTypeName(const char *format) {
981+
if (!format) return "unknown";
982+
983+
// Handle basic types - see Arrow C Data Interface spec
984+
switch (format[0]) {
985+
case 'n': return "null";
986+
case 'b': return "boolean";
987+
case 'c': return "int8";
988+
case 'C': return "uint8";
989+
case 's': return "int16";
990+
case 'S': return "uint16";
991+
case 'i': return "int32";
992+
case 'I': return "uint32";
993+
case 'l': return "int64";
994+
case 'L': return "uint64";
995+
case 'e': return "float16";
996+
case 'f': return "float32";
997+
case 'g': return "float64";
998+
case 'z': return "binary";
999+
case 'Z': return "large_binary";
1000+
case 'u': return "utf8";
1001+
case 'U': return "large_utf8";
1002+
case 'd': {
1003+
// Decimal: d:precision,scale or d:precision,scale,bitwidth
1004+
return "decimal" + string(format + 1);
1005+
}
1006+
case 'w': {
1007+
// Fixed-width binary: w:bytewidth
1008+
return "fixed_binary" + string(format + 1);
1009+
}
1010+
case 't': {
1011+
// Temporal types
1012+
if (strlen(format) < 2) return "temporal";
1013+
switch (format[1]) {
1014+
case 'd': {
1015+
// Date: tdD (days) or tdm (milliseconds)
1016+
if (strlen(format) >= 3 && format[2] == 'D') return "date32";
1017+
if (strlen(format) >= 3 && format[2] == 'm') return "date64";
1018+
return "date";
1019+
}
1020+
case 't': {
1021+
// Time: tt[smun] (seconds/millis/micros/nanos)
1022+
if (strlen(format) >= 3) {
1023+
switch (format[2]) {
1024+
case 's': return "time32[s]";
1025+
case 'm': return "time32[ms]";
1026+
case 'u': return "time64[us]";
1027+
case 'n': return "time64[ns]";
1028+
}
1029+
}
1030+
return "time";
1031+
}
1032+
case 's': {
1033+
// Timestamp: ts[smun]:timezone
1034+
string result = "timestamp";
1035+
if (strlen(format) >= 3) {
1036+
switch (format[2]) {
1037+
case 's': result += "[s]"; break;
1038+
case 'm': result += "[ms]"; break;
1039+
case 'u': result += "[us]"; break;
1040+
case 'n': result += "[ns]"; break;
1041+
}
1042+
}
1043+
// Include timezone if present
1044+
const char *tz = strchr(format, ':');
1045+
if (tz && strlen(tz) > 1) {
1046+
result += " tz=" + string(tz + 1);
1047+
}
1048+
return result;
1049+
}
1050+
case 'D': {
1051+
// Duration: tD[smun]
1052+
if (strlen(format) >= 3) {
1053+
switch (format[2]) {
1054+
case 's': return "duration[s]";
1055+
case 'm': return "duration[ms]";
1056+
case 'u': return "duration[us]";
1057+
case 'n': return "duration[ns]";
1058+
}
1059+
}
1060+
return "duration";
1061+
}
1062+
case 'i': {
1063+
// Interval: tiM (months), tiD (days/time), tin (month/day/nano)
1064+
if (strlen(format) >= 3) {
1065+
switch (format[2]) {
1066+
case 'M': return "interval[months]";
1067+
case 'D': return "interval[days]";
1068+
case 'n': return "interval[month_day_nano]";
1069+
}
1070+
}
1071+
return "interval";
1072+
}
1073+
}
1074+
return "temporal";
1075+
}
1076+
case '+': {
1077+
// Nested types
1078+
if (strlen(format) < 2) return "nested";
1079+
switch (format[1]) {
1080+
case 'l': return "list";
1081+
case 'L': return "large_list";
1082+
case 'w': return "fixed_list" + string(format + 2);
1083+
case 's': return "struct";
1084+
case 'm': return "map";
1085+
case 'u': {
1086+
// Union: +ud:type_ids or +us:type_ids
1087+
if (strlen(format) >= 3) {
1088+
if (format[2] == 'd') return "dense_union";
1089+
if (format[2] == 's') return "sparse_union";
1090+
}
1091+
return "union";
1092+
}
1093+
case 'r': return "run_end_encoded";
1094+
case 'v': {
1095+
// List view types
1096+
if (strlen(format) >= 3) {
1097+
if (format[2] == 'l') return "list_view";
1098+
if (format[2] == 'L') return "large_list_view";
1099+
}
1100+
return "list_view";
1101+
}
1102+
}
1103+
return "nested";
1104+
}
1105+
default:
1106+
// Return format string directly for unknown types
1107+
return string(format);
1108+
}
1109+
}
1110+
1111+
// Helper to extract fields from an ArrowSchema
1112+
static void ExtractSchemaFields(ArrowSchema *schema, vector<SchemaFieldRow> &field_rows) {
1113+
if (!schema) return;
1114+
1115+
for (int64_t i = 0; i < schema->n_children; i++) {
1116+
ArrowSchema *child = schema->children[i];
1117+
if (!child) continue;
1118+
1119+
SchemaFieldRow row;
1120+
row.field_name = child->name ? child->name : "";
1121+
row.field_type = ArrowFormatToTypeName(child->format);
1122+
// In Arrow C Data Interface, nullable is indicated by absence of ARROW_FLAG_NULLABLE bit NOT being set
1123+
// flags & 2 means nullable (ARROW_FLAG_NULLABLE = 2)
1124+
row.nullable = (child->flags & 2) != 0;
1125+
field_rows.push_back(row);
1126+
}
1127+
}
1128+
1129+
static unique_ptr<FunctionData> AdbcSchemaBind(ClientContext &context, TableFunctionBindInput &input,
1130+
vector<LogicalType> &return_types, vector<string> &names) {
1131+
auto bind_data = make_uniq<AdbcSchemaBindData>();
1132+
1133+
bind_data->connection_id = input.inputs[0].GetValue<int64_t>();
1134+
bind_data->table_name = input.inputs[1].GetValue<string>();
1135+
1136+
// Check for optional filter parameters
1137+
auto catalog_it = input.named_parameters.find("catalog");
1138+
if (catalog_it != input.named_parameters.end() && !catalog_it->second.IsNull()) {
1139+
bind_data->catalog_filter = catalog_it->second.GetValue<string>();
1140+
bind_data->has_catalog_filter = true;
1141+
}
1142+
1143+
auto schema_it = input.named_parameters.find("schema");
1144+
if (schema_it != input.named_parameters.end() && !schema_it->second.IsNull()) {
1145+
bind_data->schema_filter = schema_it->second.GetValue<string>();
1146+
bind_data->has_schema_filter = true;
1147+
}
1148+
1149+
auto &registry = ConnectionRegistry::Get();
1150+
bind_data->connection = registry.Get(bind_data->connection_id);
1151+
if (!bind_data->connection) {
1152+
throw InvalidInputException("adbc_schema: Invalid connection handle: " + to_string(bind_data->connection_id));
1153+
}
1154+
1155+
if (!bind_data->connection->IsInitialized()) {
1156+
throw InvalidInputException("adbc_schema: Connection has been closed");
1157+
}
1158+
1159+
// Return schema for fields
1160+
names = {"field_name", "field_type", "nullable"};
1161+
return_types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::BOOLEAN};
1162+
1163+
return std::move(bind_data);
1164+
}
1165+
1166+
static unique_ptr<GlobalTableFunctionState> AdbcSchemaInitGlobal(ClientContext &context, TableFunctionInitInput &input) {
1167+
auto &bind_data = input.bind_data->Cast<AdbcSchemaBindData>();
1168+
auto global_state = make_uniq<AdbcSchemaGlobalState>();
1169+
1170+
if (!bind_data.connection->IsInitialized()) {
1171+
throw InvalidInputException("adbc_schema: Connection has been closed");
1172+
}
1173+
1174+
const char *catalog = bind_data.has_catalog_filter ? bind_data.catalog_filter.c_str() : nullptr;
1175+
const char *db_schema = bind_data.has_schema_filter ? bind_data.schema_filter.c_str() : nullptr;
1176+
1177+
ArrowSchema schema;
1178+
memset(&schema, 0, sizeof(schema));
1179+
1180+
try {
1181+
bind_data.connection->GetTableSchema(catalog, db_schema, bind_data.table_name.c_str(), &schema);
1182+
} catch (Exception &e) {
1183+
throw IOException("adbc_schema: Failed to get table schema: " + string(e.what()));
1184+
}
1185+
1186+
// Extract fields from the schema
1187+
ExtractSchemaFields(&schema, global_state->field_rows);
1188+
1189+
// Release the schema
1190+
if (schema.release) {
1191+
schema.release(&schema);
1192+
}
1193+
1194+
return std::move(global_state);
1195+
}
1196+
1197+
static unique_ptr<LocalTableFunctionState> AdbcSchemaInitLocal(ExecutionContext &context, TableFunctionInitInput &input,
1198+
GlobalTableFunctionState *global_state_p) {
1199+
return nullptr;
1200+
}
1201+
1202+
static void AdbcSchemaFunction(ClientContext &context, TableFunctionInput &data, DataChunk &output) {
1203+
auto &global_state = data.global_state->Cast<AdbcSchemaGlobalState>();
1204+
1205+
if (global_state.current_row >= global_state.field_rows.size()) {
1206+
output.SetCardinality(0);
1207+
return;
1208+
}
1209+
1210+
idx_t count = 0;
1211+
auto &name_vector = output.data[0];
1212+
auto &type_vector = output.data[1];
1213+
auto &nullable_vector = output.data[2];
1214+
1215+
while (global_state.current_row < global_state.field_rows.size() && count < STANDARD_VECTOR_SIZE) {
1216+
auto &row = global_state.field_rows[global_state.current_row];
1217+
name_vector.SetValue(count, Value(row.field_name));
1218+
type_vector.SetValue(count, Value(row.field_type));
1219+
nullable_vector.SetValue(count, Value(row.nullable));
1220+
count++;
1221+
global_state.current_row++;
1222+
}
1223+
1224+
output.SetCardinality(count);
1225+
}
1226+
9481227
//===--------------------------------------------------------------------===//
9491228
// Register all catalog functions
9501229
//===--------------------------------------------------------------------===//
@@ -982,6 +1261,14 @@ void RegisterAdbcCatalogFunctions(DatabaseInstance &db) {
9821261
adbc_columns_function.named_parameters["column_name"] = LogicalType::VARCHAR;
9831262
adbc_columns_function.projection_pushdown = false;
9841263
loader.RegisterFunction(adbc_columns_function);
1264+
1265+
// adbc_schema(connection_id, table_name, ...) - Get Arrow schema for a table
1266+
TableFunction adbc_schema_function("adbc_schema", {LogicalType::BIGINT, LogicalType::VARCHAR}, AdbcSchemaFunction,
1267+
AdbcSchemaBind, AdbcSchemaInitGlobal, AdbcSchemaInitLocal);
1268+
adbc_schema_function.named_parameters["catalog"] = LogicalType::VARCHAR;
1269+
adbc_schema_function.named_parameters["schema"] = LogicalType::VARCHAR;
1270+
adbc_schema_function.projection_pushdown = false;
1271+
loader.RegisterFunction(adbc_schema_function);
9851272
}
9861273

9871274
} // namespace adbc

src/include/adbc_connection.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,15 @@ class AdbcConnectionWrapper {
156156
CheckAdbc(status, error.Get(), "Failed to get table types");
157157
}
158158

159+
// Get the Arrow schema for a specific table
160+
void GetTableSchema(const char *catalog, const char *db_schema,
161+
const char *table_name, ArrowSchema *schema) {
162+
AdbcErrorGuard error;
163+
auto status = AdbcConnectionGetTableSchema(&connection, catalog, db_schema,
164+
table_name, schema, error.Get());
165+
CheckAdbc(status, error.Get(), "Failed to get table schema");
166+
}
167+
159168
// Non-copyable
160169
AdbcConnectionWrapper(const AdbcConnectionWrapper &) = delete;
161170
AdbcConnectionWrapper &operator=(const AdbcConnectionWrapper &) = delete;

0 commit comments

Comments
 (0)