@@ -945,6 +945,285 @@ static void AdbcColumnsFunction(ClientContext &context, TableFunctionInput &data
945945 output.SetCardinality (count);
946946}
947947
948+ // ===--------------------------------------------------------------------===//
949+ // adbc_schema - Get Arrow schema for a specific table
950+ // ===--------------------------------------------------------------------===//
951+
952+ // Structure to hold a schema field row
953+ struct SchemaFieldRow {
954+ string field_name;
955+ string field_type;
956+ bool nullable;
957+ };
958+
959+ struct AdbcSchemaBindData : public TableFunctionData {
960+ int64_t connection_id;
961+ shared_ptr<AdbcConnectionWrapper> connection;
962+ string table_name;
963+ string catalog_filter;
964+ string schema_filter;
965+ bool has_catalog_filter = false ;
966+ bool has_schema_filter = false ;
967+ };
968+
969+ struct AdbcSchemaGlobalState : public GlobalTableFunctionState {
970+ // Extracted schema fields
971+ vector<SchemaFieldRow> field_rows;
972+ idx_t current_row = 0 ;
973+
974+ idx_t MaxThreads () const override {
975+ return 1 ;
976+ }
977+ };
978+
979+ // Helper to convert Arrow format string to human-readable type name
980+ static string ArrowFormatToTypeName (const char *format) {
981+ if (!format) return " unknown" ;
982+
983+ // Handle basic types - see Arrow C Data Interface spec
984+ switch (format[0 ]) {
985+ case ' n' : return " null" ;
986+ case ' b' : return " boolean" ;
987+ case ' c' : return " int8" ;
988+ case ' C' : return " uint8" ;
989+ case ' s' : return " int16" ;
990+ case ' S' : return " uint16" ;
991+ case ' i' : return " int32" ;
992+ case ' I' : return " uint32" ;
993+ case ' l' : return " int64" ;
994+ case ' L' : return " uint64" ;
995+ case ' e' : return " float16" ;
996+ case ' f' : return " float32" ;
997+ case ' g' : return " float64" ;
998+ case ' z' : return " binary" ;
999+ case ' Z' : return " large_binary" ;
1000+ case ' u' : return " utf8" ;
1001+ case ' U' : return " large_utf8" ;
1002+ case ' d' : {
1003+ // Decimal: d:precision,scale or d:precision,scale,bitwidth
1004+ return " decimal" + string (format + 1 );
1005+ }
1006+ case ' w' : {
1007+ // Fixed-width binary: w:bytewidth
1008+ return " fixed_binary" + string (format + 1 );
1009+ }
1010+ case ' t' : {
1011+ // Temporal types
1012+ if (strlen (format) < 2 ) return " temporal" ;
1013+ switch (format[1 ]) {
1014+ case ' d' : {
1015+ // Date: tdD (days) or tdm (milliseconds)
1016+ if (strlen (format) >= 3 && format[2 ] == ' D' ) return " date32" ;
1017+ if (strlen (format) >= 3 && format[2 ] == ' m' ) return " date64" ;
1018+ return " date" ;
1019+ }
1020+ case ' t' : {
1021+ // Time: tt[smun] (seconds/millis/micros/nanos)
1022+ if (strlen (format) >= 3 ) {
1023+ switch (format[2 ]) {
1024+ case ' s' : return " time32[s]" ;
1025+ case ' m' : return " time32[ms]" ;
1026+ case ' u' : return " time64[us]" ;
1027+ case ' n' : return " time64[ns]" ;
1028+ }
1029+ }
1030+ return " time" ;
1031+ }
1032+ case ' s' : {
1033+ // Timestamp: ts[smun]:timezone
1034+ string result = " timestamp" ;
1035+ if (strlen (format) >= 3 ) {
1036+ switch (format[2 ]) {
1037+ case ' s' : result += " [s]" ; break ;
1038+ case ' m' : result += " [ms]" ; break ;
1039+ case ' u' : result += " [us]" ; break ;
1040+ case ' n' : result += " [ns]" ; break ;
1041+ }
1042+ }
1043+ // Include timezone if present
1044+ const char *tz = strchr (format, ' :' );
1045+ if (tz && strlen (tz) > 1 ) {
1046+ result += " tz=" + string (tz + 1 );
1047+ }
1048+ return result;
1049+ }
1050+ case ' D' : {
1051+ // Duration: tD[smun]
1052+ if (strlen (format) >= 3 ) {
1053+ switch (format[2 ]) {
1054+ case ' s' : return " duration[s]" ;
1055+ case ' m' : return " duration[ms]" ;
1056+ case ' u' : return " duration[us]" ;
1057+ case ' n' : return " duration[ns]" ;
1058+ }
1059+ }
1060+ return " duration" ;
1061+ }
1062+ case ' i' : {
1063+ // Interval: tiM (months), tiD (days/time), tin (month/day/nano)
1064+ if (strlen (format) >= 3 ) {
1065+ switch (format[2 ]) {
1066+ case ' M' : return " interval[months]" ;
1067+ case ' D' : return " interval[days]" ;
1068+ case ' n' : return " interval[month_day_nano]" ;
1069+ }
1070+ }
1071+ return " interval" ;
1072+ }
1073+ }
1074+ return " temporal" ;
1075+ }
1076+ case ' +' : {
1077+ // Nested types
1078+ if (strlen (format) < 2 ) return " nested" ;
1079+ switch (format[1 ]) {
1080+ case ' l' : return " list" ;
1081+ case ' L' : return " large_list" ;
1082+ case ' w' : return " fixed_list" + string (format + 2 );
1083+ case ' s' : return " struct" ;
1084+ case ' m' : return " map" ;
1085+ case ' u' : {
1086+ // Union: +ud:type_ids or +us:type_ids
1087+ if (strlen (format) >= 3 ) {
1088+ if (format[2 ] == ' d' ) return " dense_union" ;
1089+ if (format[2 ] == ' s' ) return " sparse_union" ;
1090+ }
1091+ return " union" ;
1092+ }
1093+ case ' r' : return " run_end_encoded" ;
1094+ case ' v' : {
1095+ // List view types
1096+ if (strlen (format) >= 3 ) {
1097+ if (format[2 ] == ' l' ) return " list_view" ;
1098+ if (format[2 ] == ' L' ) return " large_list_view" ;
1099+ }
1100+ return " list_view" ;
1101+ }
1102+ }
1103+ return " nested" ;
1104+ }
1105+ default :
1106+ // Return format string directly for unknown types
1107+ return string (format);
1108+ }
1109+ }
1110+
1111+ // Helper to extract fields from an ArrowSchema
1112+ static void ExtractSchemaFields (ArrowSchema *schema, vector<SchemaFieldRow> &field_rows) {
1113+ if (!schema) return ;
1114+
1115+ for (int64_t i = 0 ; i < schema->n_children ; i++) {
1116+ ArrowSchema *child = schema->children [i];
1117+ if (!child) continue ;
1118+
1119+ SchemaFieldRow row;
1120+ row.field_name = child->name ? child->name : " " ;
1121+ row.field_type = ArrowFormatToTypeName (child->format );
1122+ // In Arrow C Data Interface, nullable is indicated by absence of ARROW_FLAG_NULLABLE bit NOT being set
1123+ // flags & 2 means nullable (ARROW_FLAG_NULLABLE = 2)
1124+ row.nullable = (child->flags & 2 ) != 0 ;
1125+ field_rows.push_back (row);
1126+ }
1127+ }
1128+
1129+ static unique_ptr<FunctionData> AdbcSchemaBind (ClientContext &context, TableFunctionBindInput &input,
1130+ vector<LogicalType> &return_types, vector<string> &names) {
1131+ auto bind_data = make_uniq<AdbcSchemaBindData>();
1132+
1133+ bind_data->connection_id = input.inputs [0 ].GetValue <int64_t >();
1134+ bind_data->table_name = input.inputs [1 ].GetValue <string>();
1135+
1136+ // Check for optional filter parameters
1137+ auto catalog_it = input.named_parameters .find (" catalog" );
1138+ if (catalog_it != input.named_parameters .end () && !catalog_it->second .IsNull ()) {
1139+ bind_data->catalog_filter = catalog_it->second .GetValue <string>();
1140+ bind_data->has_catalog_filter = true ;
1141+ }
1142+
1143+ auto schema_it = input.named_parameters .find (" schema" );
1144+ if (schema_it != input.named_parameters .end () && !schema_it->second .IsNull ()) {
1145+ bind_data->schema_filter = schema_it->second .GetValue <string>();
1146+ bind_data->has_schema_filter = true ;
1147+ }
1148+
1149+ auto ®istry = ConnectionRegistry::Get ();
1150+ bind_data->connection = registry.Get (bind_data->connection_id );
1151+ if (!bind_data->connection ) {
1152+ throw InvalidInputException (" adbc_schema: Invalid connection handle: " + to_string (bind_data->connection_id ));
1153+ }
1154+
1155+ if (!bind_data->connection ->IsInitialized ()) {
1156+ throw InvalidInputException (" adbc_schema: Connection has been closed" );
1157+ }
1158+
1159+ // Return schema for fields
1160+ names = {" field_name" , " field_type" , " nullable" };
1161+ return_types = {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::BOOLEAN};
1162+
1163+ return std::move (bind_data);
1164+ }
1165+
1166+ static unique_ptr<GlobalTableFunctionState> AdbcSchemaInitGlobal (ClientContext &context, TableFunctionInitInput &input) {
1167+ auto &bind_data = input.bind_data ->Cast <AdbcSchemaBindData>();
1168+ auto global_state = make_uniq<AdbcSchemaGlobalState>();
1169+
1170+ if (!bind_data.connection ->IsInitialized ()) {
1171+ throw InvalidInputException (" adbc_schema: Connection has been closed" );
1172+ }
1173+
1174+ const char *catalog = bind_data.has_catalog_filter ? bind_data.catalog_filter .c_str () : nullptr ;
1175+ const char *db_schema = bind_data.has_schema_filter ? bind_data.schema_filter .c_str () : nullptr ;
1176+
1177+ ArrowSchema schema;
1178+ memset (&schema, 0 , sizeof (schema));
1179+
1180+ try {
1181+ bind_data.connection ->GetTableSchema (catalog, db_schema, bind_data.table_name .c_str (), &schema);
1182+ } catch (Exception &e) {
1183+ throw IOException (" adbc_schema: Failed to get table schema: " + string (e.what ()));
1184+ }
1185+
1186+ // Extract fields from the schema
1187+ ExtractSchemaFields (&schema, global_state->field_rows );
1188+
1189+ // Release the schema
1190+ if (schema.release ) {
1191+ schema.release (&schema);
1192+ }
1193+
1194+ return std::move (global_state);
1195+ }
1196+
1197+ static unique_ptr<LocalTableFunctionState> AdbcSchemaInitLocal (ExecutionContext &context, TableFunctionInitInput &input,
1198+ GlobalTableFunctionState *global_state_p) {
1199+ return nullptr ;
1200+ }
1201+
1202+ static void AdbcSchemaFunction (ClientContext &context, TableFunctionInput &data, DataChunk &output) {
1203+ auto &global_state = data.global_state ->Cast <AdbcSchemaGlobalState>();
1204+
1205+ if (global_state.current_row >= global_state.field_rows .size ()) {
1206+ output.SetCardinality (0 );
1207+ return ;
1208+ }
1209+
1210+ idx_t count = 0 ;
1211+ auto &name_vector = output.data [0 ];
1212+ auto &type_vector = output.data [1 ];
1213+ auto &nullable_vector = output.data [2 ];
1214+
1215+ while (global_state.current_row < global_state.field_rows .size () && count < STANDARD_VECTOR_SIZE) {
1216+ auto &row = global_state.field_rows [global_state.current_row ];
1217+ name_vector.SetValue (count, Value (row.field_name ));
1218+ type_vector.SetValue (count, Value (row.field_type ));
1219+ nullable_vector.SetValue (count, Value (row.nullable ));
1220+ count++;
1221+ global_state.current_row ++;
1222+ }
1223+
1224+ output.SetCardinality (count);
1225+ }
1226+
9481227// ===--------------------------------------------------------------------===//
9491228// Register all catalog functions
9501229// ===--------------------------------------------------------------------===//
@@ -982,6 +1261,14 @@ void RegisterAdbcCatalogFunctions(DatabaseInstance &db) {
9821261 adbc_columns_function.named_parameters [" column_name" ] = LogicalType::VARCHAR;
9831262 adbc_columns_function.projection_pushdown = false ;
9841263 loader.RegisterFunction (adbc_columns_function);
1264+
1265+ // adbc_schema(connection_id, table_name, ...) - Get Arrow schema for a table
1266+ TableFunction adbc_schema_function (" adbc_schema" , {LogicalType::BIGINT, LogicalType::VARCHAR}, AdbcSchemaFunction,
1267+ AdbcSchemaBind, AdbcSchemaInitGlobal, AdbcSchemaInitLocal);
1268+ adbc_schema_function.named_parameters [" catalog" ] = LogicalType::VARCHAR;
1269+ adbc_schema_function.named_parameters [" schema" ] = LogicalType::VARCHAR;
1270+ adbc_schema_function.projection_pushdown = false ;
1271+ loader.RegisterFunction (adbc_schema_function);
9851272}
9861273
9871274} // namespace adbc
0 commit comments