|
17 | 17 | * under the License. |
18 | 18 | */ |
19 | 19 |
|
| 20 | +#include <ranges> |
| 21 | + |
20 | 22 | #include <arrow/array/builder_binary.h> |
21 | 23 | #include <arrow/array/builder_decimal.h> |
22 | 24 | #include <arrow/array/builder_nested.h> |
23 | 25 | #include <arrow/array/builder_primitive.h> |
| 26 | +#include <arrow/extension_type.h> |
24 | 27 | #include <arrow/json/from_string.h> |
25 | 28 | #include <arrow/type.h> |
26 | 29 | #include <arrow/util/decimal.h> |
@@ -451,4 +454,231 @@ Status AppendDatumToBuilder(const ::avro::NodePtr& avro_node, |
451 | 454 | projected_schema, array_builder); |
452 | 455 | } |
453 | 456 |
|
namespace {

// Union branch indices used when writing optional values into an Avro datum.
// They follow the ToAvroNodeVisitor convention: branch 0 is the null branch
// and branch 1 is the value branch.
constexpr int64_t kNullBranch{0};
constexpr int64_t kValueBranch{1};

}  // namespace
| 464 | + |
| 465 | +Status ExtractDatumFromArray(const ::arrow::Array& array, int64_t index, |
| 466 | + ::avro::GenericDatum* datum) { |
| 467 | + if (index < 0 || index >= array.length()) { |
| 468 | + return InvalidArgument("Cannot extract datum from array at index {} of length {}", |
| 469 | + index, array.length()); |
| 470 | + } |
| 471 | + |
| 472 | + if (array.IsNull(index)) { |
| 473 | + if (!datum->isUnion()) [[unlikely]] { |
| 474 | + return InvalidSchema("Cannot extract null to non-union type: {}", |
| 475 | + ::avro::toString(datum->type())); |
| 476 | + } |
| 477 | + datum->selectBranch(kNullBranch); |
| 478 | + return {}; |
| 479 | + } |
| 480 | + |
| 481 | + if (datum->isUnion()) { |
| 482 | + datum->selectBranch(kValueBranch); |
| 483 | + } |
| 484 | + |
| 485 | + switch (array.type()->id()) { |
| 486 | + case ::arrow::Type::BOOL: { |
| 487 | + const auto& bool_array = |
| 488 | + internal::checked_cast<const ::arrow::BooleanArray&>(array); |
| 489 | + datum->value<bool>() = bool_array.Value(index); |
| 490 | + return {}; |
| 491 | + } |
| 492 | + |
| 493 | + case ::arrow::Type::INT32: { |
| 494 | + const auto& int32_array = internal::checked_cast<const ::arrow::Int32Array&>(array); |
| 495 | + datum->value<int32_t>() = int32_array.Value(index); |
| 496 | + return {}; |
| 497 | + } |
| 498 | + |
| 499 | + case ::arrow::Type::INT64: { |
| 500 | + const auto& int64_array = internal::checked_cast<const ::arrow::Int64Array&>(array); |
| 501 | + datum->value<int64_t>() = int64_array.Value(index); |
| 502 | + return {}; |
| 503 | + } |
| 504 | + |
| 505 | + case ::arrow::Type::FLOAT: { |
| 506 | + const auto& float_array = internal::checked_cast<const ::arrow::FloatArray&>(array); |
| 507 | + datum->value<float>() = float_array.Value(index); |
| 508 | + return {}; |
| 509 | + } |
| 510 | + |
| 511 | + case ::arrow::Type::DOUBLE: { |
| 512 | + const auto& double_array = |
| 513 | + internal::checked_cast<const ::arrow::DoubleArray&>(array); |
| 514 | + datum->value<double>() = double_array.Value(index); |
| 515 | + return {}; |
| 516 | + } |
| 517 | + |
| 518 | + // TODO(gangwu): support LARGE_STRING. |
| 519 | + case ::arrow::Type::STRING: { |
| 520 | + const auto& string_array = |
| 521 | + internal::checked_cast<const ::arrow::StringArray&>(array); |
| 522 | + datum->value<std::string>() = string_array.GetString(index); |
| 523 | + return {}; |
| 524 | + } |
| 525 | + |
| 526 | + // TODO(gangwu): support LARGE_BINARY. |
| 527 | + case ::arrow::Type::BINARY: { |
| 528 | + const auto& binary_array = |
| 529 | + internal::checked_cast<const ::arrow::BinaryArray&>(array); |
| 530 | + std::string_view value = binary_array.GetView(index); |
| 531 | + datum->value<std::vector<uint8_t>>().assign( |
| 532 | + reinterpret_cast<const uint8_t*>(value.data()), |
| 533 | + reinterpret_cast<const uint8_t*>(value.data()) + value.size()); |
| 534 | + return {}; |
| 535 | + } |
| 536 | + |
| 537 | + case ::arrow::Type::FIXED_SIZE_BINARY: { |
| 538 | + const auto& fixed_array = |
| 539 | + internal::checked_cast<const ::arrow::FixedSizeBinaryArray&>(array); |
| 540 | + std::string_view value = fixed_array.GetView(index); |
| 541 | + auto& fixed_datum = datum->value<::avro::GenericFixed>(); |
| 542 | + fixed_datum.value().assign(value.begin(), value.end()); |
| 543 | + return {}; |
| 544 | + } |
| 545 | + |
| 546 | + case ::arrow::Type::DECIMAL128: { |
| 547 | + const auto& decimal_array = |
| 548 | + internal::checked_cast<const ::arrow::Decimal128Array&>(array); |
| 549 | + std::string_view decimal_value = decimal_array.GetView(index); |
| 550 | + auto& fixed_datum = datum->value<::avro::GenericFixed>(); |
| 551 | + auto& bytes = fixed_datum.value(); |
| 552 | + bytes.assign(decimal_value.begin(), decimal_value.end()); |
| 553 | + std::ranges::reverse(bytes); |
| 554 | + return {}; |
| 555 | + } |
| 556 | + |
| 557 | + case ::arrow::Type::DATE32: { |
| 558 | + const auto& date_array = internal::checked_cast<const ::arrow::Date32Array&>(array); |
| 559 | + datum->value<int32_t>() = date_array.Value(index); |
| 560 | + return {}; |
| 561 | + } |
| 562 | + |
| 563 | + case ::arrow::Type::TIME64: { |
| 564 | + const auto& time_array = internal::checked_cast<const ::arrow::Time64Array&>(array); |
| 565 | + datum->value<int64_t>() = time_array.Value(index); |
| 566 | + return {}; |
| 567 | + } |
| 568 | + |
| 569 | + // For both timestamp and timestamp_tz with time unit as microsecond. |
| 570 | + case ::arrow::Type::TIMESTAMP: { |
| 571 | + const auto& timestamp_array = |
| 572 | + internal::checked_cast<const ::arrow::TimestampArray&>(array); |
| 573 | + datum->value<int64_t>() = timestamp_array.Value(index); |
| 574 | + return {}; |
| 575 | + } |
| 576 | + |
| 577 | + case ::arrow::Type::EXTENSION: { |
| 578 | + if (array.type()->name() == "arrow.uuid") { |
| 579 | + const auto& extension_array = |
| 580 | + internal::checked_cast<const ::arrow::ExtensionArray&>(array); |
| 581 | + const auto& fixed_array = |
| 582 | + internal::checked_cast<const ::arrow::FixedSizeBinaryArray&>( |
| 583 | + *extension_array.storage()); |
| 584 | + std::string_view value = fixed_array.GetView(index); |
| 585 | + auto& fixed_datum = datum->value<::avro::GenericFixed>(); |
| 586 | + fixed_datum.value().assign(value.begin(), value.end()); |
| 587 | + return {}; |
| 588 | + } |
| 589 | + |
| 590 | + return NotSupported("Unsupported Arrow extension type: {}", array.type()->name()); |
| 591 | + } |
| 592 | + |
| 593 | + case ::arrow::Type::STRUCT: { |
| 594 | + const auto& struct_array = |
| 595 | + internal::checked_cast<const ::arrow::StructArray&>(array); |
| 596 | + auto& record = datum->value<::avro::GenericRecord>(); |
| 597 | + for (int i = 0; i < struct_array.num_fields(); ++i) { |
| 598 | + ICEBERG_RETURN_UNEXPECTED( |
| 599 | + ExtractDatumFromArray(*struct_array.field(i), index, &record.fieldAt(i))); |
| 600 | + } |
| 601 | + return {}; |
| 602 | + } |
| 603 | + |
| 604 | + // TODO(gangwu): support LARGE_LIST. |
| 605 | + case ::arrow::Type::LIST: { |
| 606 | + const auto& list_array = internal::checked_cast<const ::arrow::ListArray&>(array); |
| 607 | + auto& avro_array = datum->value<::avro::GenericArray>(); |
| 608 | + auto& elements = avro_array.value(); |
| 609 | + |
| 610 | + auto start = list_array.value_offset(index); |
| 611 | + auto end = list_array.value_offset(index + 1); |
| 612 | + auto length = end - start; |
| 613 | + |
| 614 | + auto values = list_array.values(); |
| 615 | + elements.resize(length, ::avro::GenericDatum(avro_array.schema()->leafAt(0))); |
| 616 | + |
| 617 | + for (int64_t i = 0; i < length; ++i) { |
| 618 | + ICEBERG_RETURN_UNEXPECTED( |
| 619 | + ExtractDatumFromArray(*values, start + i, &elements[i])); |
| 620 | + } |
| 621 | + return {}; |
| 622 | + } |
| 623 | + |
| 624 | + case ::arrow::Type::MAP: { |
| 625 | + const auto& map_array = internal::checked_cast<const ::arrow::MapArray&>(array); |
| 626 | + auto start = map_array.value_offset(index); |
| 627 | + auto end = map_array.value_offset(index + 1); |
| 628 | + auto length = end - start; |
| 629 | + |
| 630 | + auto keys = map_array.keys(); |
| 631 | + auto items = map_array.items(); |
| 632 | + |
| 633 | + if (datum->type() == ::avro::AVRO_MAP) { |
| 634 | + // Handle regular Avro map |
| 635 | + auto& avro_map = datum->value<::avro::GenericMap>(); |
| 636 | + auto value_node = avro_map.schema()->leafAt(1); |
| 637 | + |
| 638 | + auto& map_entries = avro_map.value(); |
| 639 | + map_entries.resize( |
| 640 | + length, std::make_pair(std::string(), ::avro::GenericDatum(value_node))); |
| 641 | + |
| 642 | + const auto& key_array = |
| 643 | + internal::checked_cast<const ::arrow::StringArray&>(*keys); |
| 644 | + |
| 645 | + for (int64_t i = 0; i < length; ++i) { |
| 646 | + auto& map_entry = map_entries[i]; |
| 647 | + map_entry.first = key_array.GetString(start + i); |
| 648 | + ICEBERG_RETURN_UNEXPECTED( |
| 649 | + ExtractDatumFromArray(*items, start + i, &map_entry.second)); |
| 650 | + } |
| 651 | + } else if (datum->type() == ::avro::AVRO_ARRAY) { |
| 652 | + // Handle array-based map (list<struct<key, value>>) |
| 653 | + auto& avro_array = datum->value<::avro::GenericArray>(); |
| 654 | + auto record_node = avro_array.schema()->leafAt(0); |
| 655 | + if (record_node->type() != ::avro::AVRO_RECORD || record_node->leaves() != 2) { |
| 656 | + return InvalidArgument( |
| 657 | + "Expected Avro record with 2 fields for map value, got: {}", |
| 658 | + ToString(record_node)); |
| 659 | + } |
| 660 | + |
| 661 | + auto& elements = avro_array.value(); |
| 662 | + elements.resize(length, ::avro::GenericDatum(record_node)); |
| 663 | + |
| 664 | + for (int64_t i = 0; i < length; ++i) { |
| 665 | + auto& record = elements[i].value<::avro::GenericRecord>(); |
| 666 | + ICEBERG_RETURN_UNEXPECTED( |
| 667 | + ExtractDatumFromArray(*keys, start + i, &record.fieldAt(0))); |
| 668 | + ICEBERG_RETURN_UNEXPECTED( |
| 669 | + ExtractDatumFromArray(*items, start + i, &record.fieldAt(1))); |
| 670 | + } |
| 671 | + } else { |
| 672 | + return InvalidArgument("Unsupported Avro type for map: {}", |
| 673 | + static_cast<int>(datum->type())); |
| 674 | + } |
| 675 | + return {}; |
| 676 | + } |
| 677 | + |
| 678 | + default: |
| 679 | + return InvalidArgument("Unsupported Arrow array type: {}", |
| 680 | + array.type()->ToString()); |
| 681 | + } |
| 682 | +} |
| 683 | + |
454 | 684 | } // namespace iceberg::avro |
0 commit comments