|
22 | 22 | #include <format> |
23 | 23 | #include <functional> |
24 | 24 |
|
| 25 | +#include "iceberg/schema_internal.h" |
25 | 26 | #include "iceberg/type.h" |
26 | 27 | #include "iceberg/util/formatter.h" // IWYU pragma: keep |
27 | 28 | #include "iceberg/util/macros.h" |
@@ -260,4 +261,148 @@ void NameToIdVisitor::Finish() { |
260 | 261 | } |
261 | 262 | } |
262 | 263 |
|
| 264 | +/// \brief Visitor for pruning columns based on selected field IDs. |
| 265 | +/// |
| 266 | +/// This visitor traverses a schema and creates a projected version containing only |
| 267 | +/// the specified fields. When `select_full_types` is true, a field with all its |
| 268 | +/// sub-fields are selected if its field-id has been selected; otherwise, only leaf |
| 269 | +/// fields of selected field-ids are selected. |
| 270 | +/// |
| 271 | +/// \note It returns an error when projection is not successful. |
| 272 | +class PruneColumnVisitor { |
| 273 | + public: |
| 274 | + PruneColumnVisitor(const std::unordered_set<int32_t>& selected_ids, |
| 275 | + bool select_full_types) |
| 276 | + : selected_ids_(selected_ids), select_full_types_(select_full_types) {} |
| 277 | + |
| 278 | + Result<std::shared_ptr<Type>> Visit(const std::shared_ptr<Type>& type) const { |
| 279 | + switch (type->type_id()) { |
| 280 | + case TypeId::kStruct: |
| 281 | + return Visit(internal::checked_pointer_cast<StructType>(type)); |
| 282 | + case TypeId::kList: |
| 283 | + return Visit(internal::checked_pointer_cast<ListType>(type)); |
| 284 | + case TypeId::kMap: |
| 285 | + return Visit(internal::checked_pointer_cast<MapType>(type)); |
| 286 | + default: |
| 287 | + return nullptr; |
| 288 | + } |
| 289 | + } |
| 290 | + |
| 291 | + Result<std::shared_ptr<Type>> Visit(const SchemaField& field) const { |
| 292 | + if (selected_ids_.contains(field.field_id())) { |
| 293 | + return (select_full_types_ || field.type()->is_primitive()) ? field.type() |
| 294 | + : Visit(field.type()); |
| 295 | + } |
| 296 | + return Visit(field.type()); |
| 297 | + } |
| 298 | + |
| 299 | + static SchemaField MakeField(const SchemaField& field, std::shared_ptr<Type> type) { |
| 300 | + return {field.field_id(), std::string(field.name()), std::move(type), |
| 301 | + field.optional(), std::string(field.doc())}; |
| 302 | + } |
| 303 | + |
| 304 | + Result<std::shared_ptr<Type>> Visit(const std::shared_ptr<StructType>& type) const { |
| 305 | + bool same_types = true; |
| 306 | + std::vector<SchemaField> selected_fields; |
| 307 | + for (const auto& field : type->fields()) { |
| 308 | + ICEBERG_ASSIGN_OR_RAISE(auto child_type, Visit(field)); |
| 309 | + if (child_type) { |
| 310 | + same_types = same_types && (child_type == field.type()); |
| 311 | + selected_fields.emplace_back(MakeField(field, std::move(child_type))); |
| 312 | + } |
| 313 | + } |
| 314 | + |
| 315 | + if (selected_fields.empty()) { |
| 316 | + return nullptr; |
| 317 | + } else if (same_types && selected_fields.size() == type->fields().size()) { |
| 318 | + return type; |
| 319 | + } |
| 320 | + return std::make_shared<StructType>(std::move(selected_fields)); |
| 321 | + } |
| 322 | + |
| 323 | + Result<std::shared_ptr<Type>> Visit(const std::shared_ptr<ListType>& type) const { |
| 324 | + const auto& elem_field = type->fields()[0]; |
| 325 | + ICEBERG_ASSIGN_OR_RAISE(auto elem_type, Visit(elem_field)); |
| 326 | + if (elem_type == nullptr) { |
| 327 | + return nullptr; |
| 328 | + } else if (elem_type == elem_field.type()) { |
| 329 | + return type; |
| 330 | + } |
| 331 | + return std::make_shared<ListType>(MakeField(elem_field, std::move(elem_type))); |
| 332 | + } |
| 333 | + |
| 334 | + Result<std::shared_ptr<Type>> Visit(const std::shared_ptr<MapType>& type) const { |
| 335 | + const auto& key_field = type->fields()[0]; |
| 336 | + const auto& value_field = type->fields()[1]; |
| 337 | + ICEBERG_ASSIGN_OR_RAISE(auto key_type, Visit(key_field)); |
| 338 | + ICEBERG_ASSIGN_OR_RAISE(auto value_type, Visit(value_field)); |
| 339 | + |
| 340 | + if (key_type == nullptr && value_type == nullptr) { |
| 341 | + return nullptr; |
| 342 | + } else if (value_type == value_field.type() && |
| 343 | + (key_type == key_field.type() || key_type == nullptr)) { |
| 344 | + return type; |
| 345 | + } else if (value_type == nullptr) { |
| 346 | + return InvalidArgument("Cannot project Map without value field"); |
| 347 | + } |
| 348 | + return std::make_shared<MapType>( |
| 349 | + (key_type == nullptr ? key_field : MakeField(key_field, std::move(key_type))), |
| 350 | + MakeField(value_field, std::move(value_type))); |
| 351 | + } |
| 352 | + |
| 353 | + private: |
| 354 | + const std::unordered_set<int32_t>& selected_ids_; |
| 355 | + const bool select_full_types_; |
| 356 | +}; |
| 357 | + |
| 358 | +Result<std::unique_ptr<Schema>> Schema::Select(std::span<const std::string> names, |
| 359 | + bool case_sensitive) const { |
| 360 | + const std::string kAllColumns = "*"; |
| 361 | + if (std::ranges::find(names, kAllColumns) != names.end()) { |
| 362 | + auto struct_type = ToStructType(*this); |
| 363 | + return FromStructType(std::move(*struct_type), std::nullopt); |
| 364 | + } |
| 365 | + |
| 366 | + std::unordered_set<int32_t> selected_ids; |
| 367 | + for (const auto& name : names) { |
| 368 | + ICEBERG_ASSIGN_OR_RAISE(auto result, FindFieldByName(name, case_sensitive)); |
| 369 | + if (result.has_value()) { |
| 370 | + selected_ids.insert(result.value().get().field_id()); |
| 371 | + } |
| 372 | + } |
| 373 | + |
| 374 | + PruneColumnVisitor visitor(selected_ids, /*select_full_types=*/true); |
| 375 | + ICEBERG_ASSIGN_OR_RAISE( |
| 376 | + auto pruned_type, visitor.Visit(std::shared_ptr<StructType>(ToStructType(*this)))); |
| 377 | + |
| 378 | + if (!pruned_type) { |
| 379 | + return std::make_unique<Schema>(std::vector<SchemaField>{}, std::nullopt); |
| 380 | + } |
| 381 | + |
| 382 | + if (pruned_type->type_id() != TypeId::kStruct) { |
| 383 | + return InvalidSchema("Projected type must be a struct type"); |
| 384 | + } |
| 385 | + |
| 386 | + return FromStructType(std::move(internal::checked_cast<StructType&>(*pruned_type)), |
| 387 | + std::nullopt); |
| 388 | +} |
| 389 | + |
| 390 | +Result<std::unique_ptr<Schema>> Schema::Project( |
| 391 | + const std::unordered_set<int32_t>& field_ids) const { |
| 392 | + PruneColumnVisitor visitor(field_ids, /*select_full_types=*/false); |
| 393 | + ICEBERG_ASSIGN_OR_RAISE( |
| 394 | + auto project_type, visitor.Visit(std::shared_ptr<StructType>(ToStructType(*this)))); |
| 395 | + |
| 396 | + if (!project_type) { |
| 397 | + return std::make_unique<Schema>(std::vector<SchemaField>{}, std::nullopt); |
| 398 | + } |
| 399 | + |
| 400 | + if (project_type->type_id() != TypeId::kStruct) { |
| 401 | + return InvalidSchema("Projected type must be a struct type"); |
| 402 | + } |
| 403 | + |
| 404 | + return FromStructType(std::move(internal::checked_cast<StructType&>(*project_type)), |
| 405 | + std::nullopt); |
| 406 | +} |
| 407 | + |
263 | 408 | } // namespace iceberg |
0 commit comments