|
21 | 21 |
|
22 | 22 | #include <format> |
23 | 23 |
|
24 | | -namespace iceberg { |
| 24 | +#include <nanoarrow/nanoarrow.hpp> |
| 25 | + |
| 26 | +#include "iceberg/transform/transform_function.h" |
| 27 | +#include "iceberg/transform/transform_spec.h" |
| 28 | +#include "iceberg/type.h" |
25 | 29 |
|
| 30 | +namespace iceberg { |
26 | 31 | namespace { |
27 | | -/// \brief Get the relative transform name |
28 | | -constexpr std::string_view ToString(TransformType type) { |
29 | | - switch (type) { |
30 | | - case TransformType::kUnknown: |
31 | | - return "unknown"; |
32 | | - case TransformType::kIdentity: |
33 | | - return "identity"; |
34 | | - case TransformType::kBucket: |
35 | | - return "bucket"; |
36 | | - case TransformType::kTruncate: |
37 | | - return "truncate"; |
38 | | - case TransformType::kYear: |
39 | | - return "year"; |
40 | | - case TransformType::kMonth: |
41 | | - return "month"; |
42 | | - case TransformType::kDay: |
43 | | - return "day"; |
44 | | - case TransformType::kHour: |
45 | | - return "hour"; |
46 | | - case TransformType::kVoid: |
47 | | - return "void"; |
48 | | - default: |
49 | | - return "invalid"; |
50 | | - } |
| 32 | + |
| 33 | +int32_t GetInt32FromParamArray(ArrowArray const& param_array) { |
| 34 | + ArrowArrayView view; |
| 35 | + ArrowArrayViewInitFromType(&view, NANOARROW_TYPE_INT32); |
| 36 | + NANOARROW_THROW_NOT_OK(ArrowArrayViewSetArray(&view, ¶m_array, nullptr)); |
| 37 | + const auto value = view.buffer_views[1].data.as_int32[0]; |
| 38 | + ArrowArrayViewReset(&view); |
| 39 | + return value; |
51 | 40 | } |
| 41 | + |
52 | 42 | } // namespace |
53 | 43 |
|
54 | | -TransformFunction::TransformFunction(TransformType type) : transform_type_(type) {} |
| 44 | +TransformFunction::TransformFunction(TransformType transform_type, |
| 45 | + std::shared_ptr<Type> source_type) |
| 46 | + : transform_type_(transform_type), source_type_(std::move(source_type)) {} |
55 | 47 |
|
56 | 48 | TransformType TransformFunction::transform_type() const { return transform_type_; } |
57 | 49 |
|
| 50 | +std::shared_ptr<Type> const& TransformFunction::source_type() const { |
| 51 | + return source_type_; |
| 52 | +} |
| 53 | + |
58 | 54 | std::string TransformFunction::ToString() const { |
59 | 55 | return std::format("{}", iceberg::ToString(transform_type_)); |
60 | 56 | } |
61 | 57 |
|
62 | 58 | bool TransformFunction::Equals(const TransformFunction& other) const { |
63 | | - return transform_type_ == other.transform_type_; |
| 59 | + return transform_type_ == other.transform_type_ && *source_type_ == *other.source_type_; |
64 | 60 | } |
65 | 61 |
|
66 | | -IdentityTransformFunction::IdentityTransformFunction() |
67 | | - : TransformFunction(TransformType::kIdentity) {} |
68 | | - |
69 | | -expected<ArrowArray, Error> IdentityTransformFunction::Transform( |
70 | | - const ArrowArray& input) { |
71 | | - return unexpected<Error>({.kind = ErrorKind::kNotSupported, |
72 | | - .message = "IdentityTransformFunction::Transform"}); |
| 62 | +expected<std::unique_ptr<TransformFunction>, Error> TransformFunction::Make( |
| 63 | + const TransformSpec& spec) { |
| 64 | + switch (spec.transform_type) { |
| 65 | + case TransformType::kIdentity: |
| 66 | + return std::make_unique<IdentityTransform>(spec.source_type); |
| 67 | + case TransformType::kBucket: { |
| 68 | + if (!spec.params_opt.has_value()) { |
| 69 | + return unexpected<Error>( |
| 70 | + {.kind = ErrorKind::kInvalidArgument, |
| 71 | + .message = "Bucket transform requires 1 parameter (number of buckets), but " |
| 72 | + "none were provided."}); |
| 73 | + } |
| 74 | + if (spec.params_opt->length != 1) { |
| 75 | + return unexpected<Error>( |
| 76 | + {.kind = ErrorKind::kInvalidArgument, |
| 77 | + .message = std::format("Bucket transform expects exactly 1 parameter " |
| 78 | + "(number of buckets), but got {}.", |
| 79 | + spec.params_opt->length)}); |
| 80 | + } |
| 81 | + auto num_buckets = GetInt32FromParamArray(spec.params_opt.value()); |
| 82 | + return std::make_unique<BucketTransform>(spec.source_type, num_buckets); |
| 83 | + } |
| 84 | + case TransformType::kTruncate: { |
| 85 | + if (!spec.params_opt.has_value()) { |
| 86 | + return unexpected<Error>({.kind = ErrorKind::kInvalidArgument, |
| 87 | + .message = "Truncate transform requires 1 parameter " |
| 88 | + "(width), but none were provided."}); |
| 89 | + } |
| 90 | + if (spec.params_opt->length != 1) { |
| 91 | + return unexpected<Error>( |
| 92 | + {.kind = ErrorKind::kInvalidArgument, |
| 93 | + .message = std::format( |
| 94 | + "Truncate transform expects exactly 1 parameter (width), but got {}.", |
| 95 | + spec.params_opt->length)}); |
| 96 | + } |
| 97 | + auto width = GetInt32FromParamArray(spec.params_opt.value()); |
| 98 | + return std::make_unique<TruncateTransform>(spec.source_type, width); |
| 99 | + } |
| 100 | + case TransformType::kYear: |
| 101 | + return std::make_unique<YearTransform>(spec.source_type); |
| 102 | + case TransformType::kMonth: |
| 103 | + return std::make_unique<MonthTransform>(spec.source_type); |
| 104 | + case TransformType::kDay: |
| 105 | + return std::make_unique<DayTransform>(spec.source_type); |
| 106 | + case TransformType::kHour: |
| 107 | + return std::make_unique<HourTransform>(spec.source_type); |
| 108 | + case TransformType::kVoid: |
| 109 | + return std::make_unique<VoidTransform>(spec.source_type); |
| 110 | + default: |
| 111 | + return unexpected<Error>( |
| 112 | + {.kind = ErrorKind::kInvalidArgument, |
| 113 | + .message = std::format("Unsupported or invalid transform type: {}", |
| 114 | + iceberg::ToString(spec.transform_type))}); |
| 115 | + } |
73 | 116 | } |
74 | 117 |
|
75 | 118 | } // namespace iceberg |
0 commit comments