diff --git a/src/iceberg/CMakeLists.txt b/src/iceberg/CMakeLists.txt index 9e0670ee8..9b634014b 100644 --- a/src/iceberg/CMakeLists.txt +++ b/src/iceberg/CMakeLists.txt @@ -31,6 +31,7 @@ set(ICEBERG_SOURCES statistics_file.cc table_metadata.cc transform.cc + transform_function.cc type.cc) set(ICEBERG_STATIC_BUILD_INTERFACE_LIBS) diff --git a/src/iceberg/json_internal.cc b/src/iceberg/json_internal.cc index 0905dfc6c..c6fe570c7 100644 --- a/src/iceberg/json_internal.cc +++ b/src/iceberg/json_internal.cc @@ -117,7 +117,7 @@ Result> SortFieldFromJson(const nlohmann::json& json) ICEBERG_ASSIGN_OR_RAISE(auto source_id, GetJsonValue(json, kSourceId)); ICEBERG_ASSIGN_OR_RAISE( auto transform, - GetJsonValue(json, kTransform).and_then(TransformFunctionFromString)); + GetJsonValue(json, kTransform).and_then(TransformFromString)); ICEBERG_ASSIGN_OR_RAISE( auto direction, GetJsonValue(json, kDirection).and_then(SortDirectionFromString)); @@ -401,7 +401,7 @@ Result> PartitionFieldFromJson( ICEBERG_ASSIGN_OR_RAISE(auto field_id, GetJsonValue(json, kFieldId)); ICEBERG_ASSIGN_OR_RAISE( auto transform, - GetJsonValue(json, kTransform).and_then(TransformFunctionFromString)); + GetJsonValue(json, kTransform).and_then(TransformFromString)); ICEBERG_ASSIGN_OR_RAISE(auto name, GetJsonValue(json, kName)); return std::make_unique(source_id, field_id, name, std::move(transform)); diff --git a/src/iceberg/partition_field.cc b/src/iceberg/partition_field.cc index ddfe8bada..59e6afcd8 100644 --- a/src/iceberg/partition_field.cc +++ b/src/iceberg/partition_field.cc @@ -28,7 +28,7 @@ namespace iceberg { PartitionField::PartitionField(int32_t source_id, int32_t field_id, std::string name, - std::shared_ptr transform) + std::shared_ptr transform) : source_id_(source_id), field_id_(field_id), name_(std::move(name)), @@ -40,9 +40,7 @@ int32_t PartitionField::field_id() const { return field_id_; } std::string_view PartitionField::name() const { return name_; } -std::shared_ptr const& PartitionField::transform() const { - return transform_; -} +std::shared_ptr const& PartitionField::transform() const { return transform_; } std::string PartitionField::ToString() const { return std::format("{} ({} {}({}))", name_, field_id_, *transform_, source_id_); diff --git a/src/iceberg/partition_field.h b/src/iceberg/partition_field.h index 3b75f21ce..e31be911b 100644 --- a/src/iceberg/partition_field.h +++ b/src/iceberg/partition_field.h @@ -43,7 +43,7 @@ class ICEBERG_EXPORT PartitionField : public util::Formattable { /// \param[in] name The partition field name. /// \param[in] transform The transform function. PartitionField(int32_t source_id, int32_t field_id, std::string name, - std::shared_ptr transform); + std::shared_ptr transform); /// \brief Get the source field ID. int32_t source_id() const; @@ -55,7 +55,7 @@ class ICEBERG_EXPORT PartitionField : public util::Formattable { std::string_view name() const; /// \brief Get the transform type. - std::shared_ptr const& transform() const; + std::shared_ptr const& transform() const; std::string ToString() const override; @@ -74,7 +74,7 @@ class ICEBERG_EXPORT PartitionField : public util::Formattable { int32_t source_id_; int32_t field_id_; std::string name_; - std::shared_ptr transform_; + std::shared_ptr transform_; }; } // namespace iceberg diff --git a/src/iceberg/sort_field.cc b/src/iceberg/sort_field.cc index ae5464b61..b96d01505 100644 --- a/src/iceberg/sort_field.cc +++ b/src/iceberg/sort_field.cc @@ -27,7 +27,7 @@ namespace iceberg { -SortField::SortField(int32_t source_id, std::shared_ptr transform, +SortField::SortField(int32_t source_id, std::shared_ptr transform, SortDirection direction, NullOrder null_order) : source_id_(source_id), transform_(std::move(transform)), @@ -36,9 +36,7 @@ SortField::SortField(int32_t source_id, std::shared_ptr trans int32_t SortField::source_id() const { return source_id_; } -std::shared_ptr const& SortField::transform() const { - return transform_; -} +std::shared_ptr const& SortField::transform() const { return transform_; } SortDirection SortField::direction() const { return direction_; } diff --git a/src/iceberg/sort_field.h b/src/iceberg/sort_field.h index 6037cec72..26879142b 100644 --- a/src/iceberg/sort_field.h +++ b/src/iceberg/sort_field.h @@ -97,14 +97,14 @@ class ICEBERG_EXPORT SortField : public util::Formattable { /// \param[in] transform The transform function. /// \param[in] direction The sort direction. /// \param[in] null_order The null order. - SortField(int32_t source_id, std::shared_ptr transform, + SortField(int32_t source_id, std::shared_ptr transform, SortDirection direction, NullOrder null_order); /// \brief Get the source field ID. int32_t source_id() const; /// \brief Get the transform type. - const std::shared_ptr& transform() const; + const std::shared_ptr& transform() const; /// \brief Get the sort direction. SortDirection direction() const; @@ -127,7 +127,7 @@ class ICEBERG_EXPORT SortField : public util::Formattable { [[nodiscard]] bool Equals(const SortField& other) const; int32_t source_id_; - std::shared_ptr transform_; + std::shared_ptr transform_; SortDirection direction_; NullOrder null_order_; }; diff --git a/src/iceberg/transform.cc b/src/iceberg/transform.cc index ed3708238..8ba12ce6b 100644 --- a/src/iceberg/transform.cc +++ b/src/iceberg/transform.cc @@ -20,66 +20,206 @@ #include "iceberg/transform.h" #include +#include -namespace iceberg { +#include "iceberg/transform_function.h" +#include "iceberg/type.h" +namespace iceberg { namespace { -/// \brief Get the relative transform name -constexpr std::string_view ToString(TransformType type) { +constexpr std::string_view kUnknownName = "unknown"; +constexpr std::string_view kIdentityName = "identity"; +constexpr std::string_view kBucketName = "bucket"; +constexpr std::string_view kTruncateName = "truncate"; +constexpr std::string_view kYearName = "year"; +constexpr std::string_view kMonthName = "month"; +constexpr std::string_view kDayName = "day"; +constexpr std::string_view kHourName = "hour"; +constexpr std::string_view kVoidName = "void"; +} // namespace + +constexpr std::string_view TransformTypeToString(TransformType type) { switch (type) { case TransformType::kUnknown: - return "unknown"; + return kUnknownName; case TransformType::kIdentity: - return "identity"; + return kIdentityName; case TransformType::kBucket: - return "bucket"; + return kBucketName; case TransformType::kTruncate: - return "truncate"; + return kTruncateName; case TransformType::kYear: - return "year"; + return kYearName; case TransformType::kMonth: - return "month"; + return kMonthName; case TransformType::kDay: - return "day"; + return kDayName; case TransformType::kHour: - return "hour"; + return kHourName; case TransformType::kVoid: - return "void"; - default: - return "invalid"; + return kVoidName; } } -} // namespace -TransformFunction::TransformFunction(TransformType type) : transform_type_(type) {} +std::shared_ptr Transform::Identity() { + static auto instance = + std::shared_ptr(new Transform(TransformType::kIdentity)); + return instance; +} + +std::shared_ptr Transform::Year() { + static auto instance = std::shared_ptr(new Transform(TransformType::kYear)); + return instance; +} -TransformType TransformFunction::transform_type() const { return transform_type_; } +std::shared_ptr Transform::Month() { + static auto instance = std::shared_ptr(new Transform(TransformType::kMonth)); + return instance; +} + +std::shared_ptr Transform::Day() { + static auto instance = std::shared_ptr(new Transform(TransformType::kDay)); + return instance; +} + +std::shared_ptr Transform::Hour() { + static auto instance = std::shared_ptr(new Transform(TransformType::kHour)); + return instance; +} + +std::shared_ptr Transform::Void() { + static auto instance = std::shared_ptr(new Transform(TransformType::kVoid)); + return instance; +} + +std::shared_ptr Transform::Bucket(int32_t num_buckets) { + return std::shared_ptr(new Transform(TransformType::kBucket, num_buckets)); +} + +std::shared_ptr Transform::Truncate(int32_t width) { + return std::shared_ptr(new Transform(TransformType::kTruncate, width)); +} + +Transform::Transform(TransformType transform_type) : transform_type_(transform_type) {} + +Transform::Transform(TransformType transform_type, int32_t param) + : transform_type_(transform_type), param_(param) {} + +TransformType Transform::transform_type() const { return transform_type_; } + +Result> Transform::Bind( + const std::shared_ptr& source_type) const { + auto type_str = TransformTypeToString(transform_type_); + + switch (transform_type_) { + case TransformType::kIdentity: + return std::make_unique(source_type); + + case TransformType::kBucket: { + if (auto param = std::get_if(¶m_)) { + return std::make_unique(source_type, *param); + } + return unexpected({ + .kind = ErrorKind::kInvalidArgument, + .message = std::format( + "Bucket requires int32 param, none found in transform '{}'", type_str), + }); + } -std::string TransformFunction::ToString() const { - return std::format("{}", iceberg::ToString(transform_type_)); + case TransformType::kTruncate: { + if (auto param = std::get_if(¶m_)) { + return std::make_unique(source_type, *param); + } + return unexpected({ + .kind = ErrorKind::kInvalidArgument, + .message = std::format( + "Truncate requires int32 param, none found in transform '{}'", type_str), + }); + } + + case TransformType::kYear: + return std::make_unique(source_type); + case TransformType::kMonth: + return std::make_unique(source_type); + case TransformType::kDay: + return std::make_unique(source_type); + case TransformType::kHour: + return std::make_unique(source_type); + case TransformType::kVoid: + return std::make_unique(source_type); + + default: + return unexpected({ + .kind = ErrorKind::kNotSupported, + .message = std::format("Unsupported transform type: '{}'", type_str), + }); + } } bool TransformFunction::Equals(const TransformFunction& other) const { - return transform_type_ == other.transform_type_; + return transform_type_ == other.transform_type_ && *source_type_ == *other.source_type_; +} + +std::string Transform::ToString() const { + switch (transform_type_) { + case TransformType::kIdentity: + case TransformType::kYear: + case TransformType::kMonth: + case TransformType::kDay: + case TransformType::kHour: + case TransformType::kVoid: + case TransformType::kUnknown: + return std::format("{}", TransformTypeToString(transform_type_)); + case TransformType::kBucket: + case TransformType::kTruncate: + return std::format("{}[{}]", TransformTypeToString(transform_type_), + std::get(param_)); + } } -IdentityTransformFunction::IdentityTransformFunction() - : TransformFunction(TransformType::kIdentity) {} +TransformFunction::TransformFunction(TransformType transform_type, + std::shared_ptr source_type) + : transform_type_(transform_type), source_type_(std::move(source_type)) {} -expected IdentityTransformFunction::Transform( - const ArrowArray& input) { - return unexpected({.kind = ErrorKind::kNotSupported, - .message = "IdentityTransformFunction::Transform"}); +TransformType TransformFunction::transform_type() const { return transform_type_; } + +std::shared_ptr const& TransformFunction::source_type() const { + return source_type_; +} + +bool Transform::Equals(const Transform& other) const { + return transform_type_ == other.transform_type_ && param_ == other.param_; } -expected, Error> TransformFunctionFromString( - std::string_view str) { - if (str == "identity") { - return std::make_unique(); +Result> TransformFromString(std::string_view transform_str) { + if (transform_str == kIdentityName) return Transform::Identity(); + if (transform_str == kYearName) return Transform::Year(); + if (transform_str == kMonthName) return Transform::Month(); + if (transform_str == kDayName) return Transform::Day(); + if (transform_str == kHourName) return Transform::Hour(); + if (transform_str == kVoidName) return Transform::Void(); + + // Match bucket[16] or truncate[4] + static const std::regex param_regex( + std::format(R"(({}|{})\[(\d+)\])", kBucketName, kTruncateName)); + std::string str(transform_str); + std::smatch match; + if (std::regex_match(str, match, param_regex)) { + const std::string type_str = match[1]; + const int32_t param = std::stoi(match[2]); + + if (type_str == kBucketName) { + return Transform::Bucket(param); + } + if (type_str == kTruncateName) { + return Transform::Truncate(param); + } } - return unexpected( - {.kind = ErrorKind::kInvalidArgument, - .message = "Invalid TransformFunction string: " + std::string(str)}); + + return unexpected({ + .kind = ErrorKind::kInvalidArgument, + .message = std::format("Invalid Transform string: {}", transform_str), + }); } } // namespace iceberg diff --git a/src/iceberg/transform.h b/src/iceberg/transform.h index 4e8ecfc0d..05d6799f7 100644 --- a/src/iceberg/transform.h +++ b/src/iceberg/transform.h @@ -23,6 +23,7 @@ #include #include +#include #include "iceberg/arrow_c_data.h" #include "iceberg/expected.h" @@ -56,16 +57,133 @@ enum class TransformType { kVoid, }; +/// \brief Get the relative transform name +ICEBERG_EXPORT constexpr std::string_view TransformTypeToString(TransformType type); + +/// \brief Represents a transform used in partitioning or sorting in Iceberg. +/// +/// This class supports binding to a source type and instantiating the corresponding +/// TransformFunction, as well as serialization-friendly introspection. +class ICEBERG_EXPORT Transform : public util::Formattable { + public: + /// \brief Returns a shared singleton instance of the Identity transform. + /// + /// This transform leaves values unchanged and is commonly used for direct partitioning. + /// \return A shared pointer to the Identity transform. + static std::shared_ptr Identity(); + + /// \brief Creates a shared instance of the Bucket transform. + /// + /// Buckets values using a hash modulo operation. Commonly used for distributing data. + /// \param num_buckets The number of buckets. + /// \return A shared pointer to the Bucket transform. + static std::shared_ptr Bucket(int32_t num_buckets); + + /// \brief Creates a shared instance of the Truncate transform. + /// + /// Truncates values to a fixed width (e.g., for strings or binary data). + /// \param width The width to truncate to. + /// \return A shared pointer to the Truncate transform. + static std::shared_ptr Truncate(int32_t width); + + /// \brief Creates a shared singleton instance of the Year transform. + /// + /// Extracts the year portion from a date or timestamp. + /// \return A shared pointer to the Year transform. + static std::shared_ptr Year(); + + /// \brief Creates a shared singleton instance of the Month transform. + /// + /// Extracts the month portion from a date or timestamp. + /// \return A shared pointer to the Month transform. + static std::shared_ptr Month(); + + /// \brief Creates a shared singleton instance of the Day transform. + /// + /// Extracts the day portion from a date or timestamp. + /// \return A shared pointer to the Day transform. + static std::shared_ptr Day(); + + /// \brief Creates a shared singleton instance of the Hour transform. + /// + /// Extracts the hour portion from a timestamp. + /// \return A shared pointer to the Hour transform. + static std::shared_ptr Hour(); + + /// \brief Creates a shared singleton instance of the Void transform. + /// + /// Ignores values and always returns null. Useful for testing or special cases. + /// \return A shared pointer to the Void transform. + static std::shared_ptr Void(); + + /// \brief Returns the transform type. + TransformType transform_type() const; + + /// \brief Binds this transform to a source type, returning a typed TransformFunction. + /// + /// This creates a concrete transform implementation based on the transform type and + /// parameter. + /// \param source_type The source column type to bind to. + /// \return A TransformFunction instance wrapped in `expected`, or an error on failure. + Result> Bind( + const std::shared_ptr& source_type) const; + + /// \brief Returns a string representation of this transform (e.g., "bucket[16]"). + std::string ToString() const override; + + /// \brief Equality comparison. + friend bool operator==(const Transform& lhs, const Transform& rhs) { + return lhs.Equals(rhs); + } + + /// \brief Inequality comparison. + friend bool operator!=(const Transform& lhs, const Transform& rhs) { + return !(lhs == rhs); + } + + private: + /// \brief Constructs a Transform of the specified type (for non-parametric types). + /// \param transform_type The transform type (e.g., identity, year, day). + explicit Transform(TransformType transform_type); + + /// \brief Constructs a parameterized Transform (e.g., bucket(16), truncate(4)). + /// \param transform_type The transform type. + /// \param param The integer parameter associated with the transform. + Transform(TransformType transform_type, int32_t param); + + /// \brief Checks equality with another Transform instance. + [[nodiscard]] virtual bool Equals(const Transform& other) const; + + TransformType transform_type_; + /// Optional parameter (e.g., num_buckets, width) + std::variant param_; +}; +/// \brief Converts a string representation of a transform into a Transform instance. +/// +/// This function parses the provided string to identify the corresponding transform type +/// (e.g., "identity", "year", "bucket[16]"), and creates a shared pointer to the +/// corresponding Transform object. It supports both simple transforms (like "identity") +/// and parameterized transforms (like "bucket[16]" or "truncate[4]"). +/// +/// \param transform_str The string representation of the transform type. +/// \return A Result containing either a shared pointer to the corresponding Transform +/// instance or an Error if the string does not match any valid transform type. +ICEBERG_EXPORT Result> TransformFromString( + std::string_view transform_str); + /// \brief A transform function used for partitioning. -class ICEBERG_EXPORT TransformFunction : public util::Formattable { +class ICEBERG_EXPORT TransformFunction { public: - explicit TransformFunction(TransformType type); + virtual ~TransformFunction() = default; + TransformFunction(TransformType transform_type, std::shared_ptr source_type); /// \brief Transform an input array to a new array - virtual expected Transform(const ArrowArray& data) = 0; + virtual Result Transform(const ArrowArray& data) = 0; /// \brief Get the transform type - virtual TransformType transform_type() const; - - std::string ToString() const override; + TransformType transform_type() const; + /// \brief Get the source type of transform function + const std::shared_ptr& source_type() const; + /// \brief Get the result type of transform function + virtual Result> ResultType() const = 0; friend bool operator==(const TransformFunction& lhs, const TransformFunction& rhs) { return lhs.Equals(rhs); @@ -80,16 +198,7 @@ class ICEBERG_EXPORT TransformFunction : public util::Formattable { [[nodiscard]] virtual bool Equals(const TransformFunction& other) const; TransformType transform_type_; -}; - -ICEBERG_EXPORT expected, Error> -TransformFunctionFromString(std::string_view str); - -class ICEBERG_EXPORT IdentityTransformFunction : public TransformFunction { - public: - IdentityTransformFunction(); - /// \brief Transform will take an input array and transform it into a new array. - expected Transform(const ArrowArray& input) override; + std::shared_ptr source_type_; }; } // namespace iceberg diff --git a/src/iceberg/transform_function.cc b/src/iceberg/transform_function.cc new file mode 100644 index 000000000..eb01fb8ba --- /dev/null +++ b/src/iceberg/transform_function.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/transform_function.h" + +#include + +#include "iceberg/type.h" + +namespace iceberg { + +IdentityTransform::IdentityTransform(std::shared_ptr const& source_type) + : TransformFunction(TransformType::kIdentity, source_type) {} + +Result IdentityTransform::Transform(const ArrowArray& input) { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "IdentityTransform::Transform"}); +} + +Result> IdentityTransform::ResultType() const { + auto src_type = source_type(); + if (!src_type || !src_type->is_primitive()) { + return unexpected(Error{ + .kind = ErrorKind::kNotSupported, + .message = std::format("{} is not a valid input type for identity transform", + src_type ? src_type->ToString() : "null")}); + } + return src_type; +} + +BucketTransform::BucketTransform(std::shared_ptr const& source_type, + int32_t num_buckets) + : TransformFunction(TransformType::kBucket, source_type), num_buckets_(num_buckets) {} + +Result BucketTransform::Transform(const ArrowArray& input) { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "BucketTransform::Transform"}); +} + +Result> BucketTransform::ResultType() const { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "BucketTransform::result_type"}); +} + +TruncateTransform::TruncateTransform(std::shared_ptr const& source_type, + int32_t width) + : TransformFunction(TransformType::kTruncate, source_type), width_(width) {} + +Result TruncateTransform::Transform(const ArrowArray& input) { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "TruncateTransform::Transform"}); +} + +Result> TruncateTransform::ResultType() const { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "TruncateTransform::result_type"}); +} + +YearTransform::YearTransform(std::shared_ptr const& source_type) + : TransformFunction(TransformType::kTruncate, source_type) {} + +Result YearTransform::Transform(const ArrowArray& input) { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "YearTransform::Transform"}); +} + +Result> YearTransform::ResultType() const { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "YearTransform::result_type"}); +} + +MonthTransform::MonthTransform(std::shared_ptr const& source_type) + : TransformFunction(TransformType::kMonth, source_type) {} + +Result MonthTransform::Transform(const ArrowArray& input) { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "MonthTransform::Transform"}); +} + +Result> MonthTransform::ResultType() const { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "MonthTransform::result_type"}); +} + +DayTransform::DayTransform(std::shared_ptr const& source_type) + : TransformFunction(TransformType::kDay, source_type) {} + +Result DayTransform::Transform(const ArrowArray& input) { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "DayTransform::Transform"}); +} + +Result> DayTransform::ResultType() const { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "DayTransform::result_type"}); +} + +HourTransform::HourTransform(std::shared_ptr const& source_type) + : TransformFunction(TransformType::kHour, source_type) {} + +Result HourTransform::Transform(const ArrowArray& input) { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "HourTransform::Transform"}); +} + +Result> HourTransform::ResultType() const { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "HourTransform::result_type"}); +} + +VoidTransform::VoidTransform(std::shared_ptr const& source_type) + : TransformFunction(TransformType::kVoid, source_type) {} + +Result VoidTransform::Transform(const ArrowArray& input) { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "VoidTransform::Transform"}); +} + +Result> VoidTransform::ResultType() const { + return unexpected( + {.kind = ErrorKind::kNotImplemented, .message = "VoidTransform::result_type"}); +} + +} // namespace iceberg diff --git a/src/iceberg/transform_function.h b/src/iceberg/transform_function.h new file mode 100644 index 000000000..eb844324c --- /dev/null +++ b/src/iceberg/transform_function.h @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +/// \file iceberg/transform_function.h + +#include "iceberg/transform.h" + +namespace iceberg { +/// \brief Identity transform that returns the input unchanged. +class IdentityTransform : public TransformFunction { + public: + /// \param source_type Type of the input data. + explicit IdentityTransform(std::shared_ptr const& source_type); + + /// \brief Returns the input array without modification. + Result Transform(const ArrowArray& input) override; + + /// \brief Returns the same type as the source type if it is valid. + Result> ResultType() const override; +}; + +/// \brief Bucket transform that hashes input values into N buckets. +class BucketTransform : public TransformFunction { + public: + /// \param source_type Type of the input data. + /// \param num_buckets Number of buckets to hash into. + BucketTransform(std::shared_ptr const& source_type, int32_t num_buckets); + + /// \brief Applies the bucket hash function to the input array. + Result Transform(const ArrowArray& input) override; + + /// \brief Returns INT32 as the output type. + Result> ResultType() const override; + + private: + int32_t num_buckets_; +}; + +/// \brief Truncate transform that truncates values to a specified width. +class TruncateTransform : public TransformFunction { + public: + /// \param source_type Type of the input data. + /// \param width The width to truncate to (e.g., for strings or numbers). + TruncateTransform(std::shared_ptr const& source_type, int32_t width); + + /// \brief Truncates values in the input array to the specified width. + Result Transform(const ArrowArray& input) override; + + /// \brief Returns the same type as source_type. + Result> ResultType() const override; + + private: + int32_t width_; +}; + +/// \brief Year transform that extracts the year component from timestamp inputs. +class YearTransform : public TransformFunction { + public: + /// \param source_type Must be a timestamp type. + explicit YearTransform(std::shared_ptr const& source_type); + + /// \brief Extracts the year from each timestamp in the input array. + Result Transform(const ArrowArray& input) override; + + /// \brief Returns INT32 as the output type. + Result> ResultType() const override; +}; + +/// \brief Month transform that extracts the month component from timestamp inputs. +class MonthTransform : public TransformFunction { + public: + /// \param source_type Must be a timestamp type. + explicit MonthTransform(std::shared_ptr const& source_type); + + /// \brief Extracts the month (1-12) from each timestamp in the input array. + Result Transform(const ArrowArray& input) override; + + /// \brief Returns INT32 as the output type. + Result> ResultType() const override; +}; + +/// \brief Day transform that extracts the day of the month from timestamp inputs. +class DayTransform : public TransformFunction { + public: + /// \param source_type Must be a timestamp type. + explicit DayTransform(std::shared_ptr const& source_type); + + /// \brief Extracts the day (1-31) from each timestamp in the input array. + Result Transform(const ArrowArray& input) override; + + /// \brief Returns INT32 as the output type. + Result> ResultType() const override; +}; + +/// \brief Hour transform that extracts the hour component from timestamp inputs. +class HourTransform : public TransformFunction { + public: + /// \param source_type Must be a timestamp type. + explicit HourTransform(std::shared_ptr const& source_type); + + /// \brief Extracts the hour (0-23) from each timestamp in the input array. + Result Transform(const ArrowArray& input) override; + + /// \brief Returns INT32 as the output type. + Result> ResultType() const override; +}; + +/// \brief Void transform that discards the input and always returns null. +class VoidTransform : public TransformFunction { + public: + /// \param source_type Input type (ignored). + explicit VoidTransform(std::shared_ptr const& source_type); + + /// \brief Returns an all-null array of the same length as the input. + Result Transform(const ArrowArray& input) override; + + /// \brief Returns null type or a sentinel type indicating void. + Result> ResultType() const override; +}; + +} // namespace iceberg diff --git a/src/iceberg/type_fwd.h b/src/iceberg/type_fwd.h index bc55aebb7..519164ef9 100644 --- a/src/iceberg/type_fwd.h +++ b/src/iceberg/type_fwd.h @@ -101,6 +101,7 @@ class SortField; class SortOrder; class StructLike; class TableMetadata; +class Transform; enum class TransformType; class TransformFunction; diff --git a/test/json_internal_test.cc b/test/json_internal_test.cc index 75c020a0d..ff6361d67 100644 --- a/test/json_internal_test.cc +++ b/test/json_internal_test.cc @@ -70,7 +70,7 @@ void TestJsonConversion(const T& obj, const nlohmann::json& expected_json) { } // namespace TEST(JsonInternalTest, SortField) { - auto identity_transform = std::make_shared(); + auto identity_transform = Transform::Identity(); // Test for SortField with ascending order SortField sort_field_asc(5, identity_transform, SortDirection::kAscending, @@ -88,7 +88,7 @@ TEST(JsonInternalTest, SortField) { } TEST(JsonInternalTest, SortOrder) { - auto identity_transform = std::make_shared(); + auto identity_transform = Transform::Identity(); SortField st_ts(5, identity_transform, SortDirection::kAscending, NullOrder::kFirst); SortField st_bar(7, identity_transform, SortDirection::kDescending, NullOrder::kLast); SortOrder sort_order(100, {st_ts, st_bar}); @@ -102,7 +102,7 @@ TEST(JsonInternalTest, SortOrder) { } TEST(JsonInternalTest, PartitionField) { - auto identity_transform = std::make_shared(); + auto identity_transform = Transform::Identity(); PartitionField field(3, 101, "region", identity_transform); nlohmann::json expected_json = R"({"source-id":3,"field-id":101,"transform":"identity","name":"region"})"_json; @@ -125,7 +125,7 @@ TEST(JsonPartitionTest, PartitionSpec) { SchemaField(3, "region", std::make_shared(), false), SchemaField(5, "ts", std::make_shared(), false)}); - auto identity_transform = std::make_shared(); + auto identity_transform = Transform::Identity(); PartitionSpec spec(schema, 1, {PartitionField(3, 101, "region", identity_transform), PartitionField(5, 102, "ts", identity_transform)}); diff --git a/test/partition_field_test.cc b/test/partition_field_test.cc index 6c21c1fa3..11f22477e 100644 --- a/test/partition_field_test.cc +++ b/test/partition_field_test.cc @@ -22,27 +22,16 @@ #include #include +#include #include "iceberg/transform.h" #include "iceberg/util/formatter.h" namespace iceberg { -namespace { -class TestTransformFunction : public TransformFunction { - public: - TestTransformFunction() : TransformFunction(TransformType::kUnknown) {} - expected Transform(const ArrowArray& input) override { - return unexpected( - Error{.kind = ErrorKind::kNotSupported, .message = "test transform function"}); - } -}; - -} // namespace - TEST(PartitionFieldTest, Basics) { { - const auto transform = std::make_shared(); + auto transform = Transform::Identity(); PartitionField field(1, 1000, "pt", transform); EXPECT_EQ(1, field.source_id()); EXPECT_EQ(1000, field.field_id()); @@ -54,13 +43,13 @@ TEST(PartitionFieldTest, Basics) { } TEST(PartitionFieldTest, Equality) { - auto test_transform = std::make_shared(); - auto identity_transform = std::make_shared(); + const auto bucket_transform = Transform::Bucket(8); + const auto identity_transform = Transform::Identity(); - PartitionField field1(1, 10000, "pt", test_transform); - PartitionField field2(2, 10000, "pt", test_transform); - PartitionField field3(1, 10001, "pt", test_transform); - PartitionField field4(1, 10000, "pt2", test_transform); + PartitionField field1(1, 10000, "pt", bucket_transform); + PartitionField field2(2, 10000, "pt", bucket_transform); + PartitionField field3(1, 10001, "pt", bucket_transform); + PartitionField field4(1, 10000, "pt2", bucket_transform); PartitionField field5(1, 10000, "pt", identity_transform); ASSERT_EQ(field1, field1); diff --git a/test/partition_spec_test.cc b/test/partition_spec_test.cc index 9e486557a..b60e2886d 100644 --- a/test/partition_spec_test.cc +++ b/test/partition_spec_test.cc @@ -38,7 +38,7 @@ TEST(PartitionSpecTest, Basics) { SchemaField field2(7, "bar", std::make_shared(), true); auto const schema = std::make_shared(100, std::vector{field1, field2}); - auto identity_transform = std::make_shared(); + auto identity_transform = Transform::Identity(); PartitionField pt_field1(5, 1000, "day", identity_transform); PartitionField pt_field2(5, 1001, "hour", identity_transform); PartitionSpec spec(schema, 100, {pt_field1, pt_field2}); @@ -61,7 +61,7 @@ TEST(PartitionSpecTest, Equality) { SchemaField field1(5, "ts", std::make_shared(), true); SchemaField field2(7, "bar", std::make_shared(), true); auto const schema = std::make_shared(100, std::vector{field1, field2}); - auto identity_transform = std::make_shared(); + auto identity_transform = Transform::Identity(); PartitionField pt_field1(5, 1000, "day", identity_transform); PartitionField pt_field2(7, 1001, "hour", identity_transform); PartitionField pt_field3(7, 1001, "hour", identity_transform); diff --git a/test/sort_field_test.cc b/test/sort_field_test.cc index 2141a3db1..0a8c407b7 100644 --- a/test/sort_field_test.cc +++ b/test/sort_field_test.cc @@ -24,25 +24,14 @@ #include #include "iceberg/transform.h" +#include "iceberg/type.h" #include "iceberg/util/formatter.h" namespace iceberg { -namespace { -class TestTransformFunction : public TransformFunction { - public: - TestTransformFunction() : TransformFunction(TransformType::kUnknown) {} - expected Transform(const ArrowArray& input) override { - return unexpected( - Error{.kind = ErrorKind::kNotSupported, .message = "test transform function"}); - } -}; - -} // namespace - TEST(SortFieldTest, Basics) { { - const auto transform = std::make_shared(); + const auto transform = Transform::Identity(); SortField field(1, transform, SortDirection::kAscending, NullOrder::kFirst); EXPECT_EQ(1, field.source_id()); EXPECT_EQ(*transform, *field.transform()); @@ -60,14 +49,14 @@ TEST(SortFieldTest, Basics) { } TEST(SortFieldTest, Equality) { - auto test_transform = std::make_shared(); - auto identity_transform = std::make_shared(); + const auto bucket_transform = Transform::Bucket(8); + const auto identity_transform = Transform::Identity(); - SortField field1(1, test_transform, SortDirection::kAscending, NullOrder::kFirst); - SortField field2(2, test_transform, SortDirection::kAscending, NullOrder::kFirst); + SortField field1(1, bucket_transform, SortDirection::kAscending, NullOrder::kFirst); + SortField field2(2, bucket_transform, SortDirection::kAscending, NullOrder::kFirst); SortField field3(1, identity_transform, SortDirection::kAscending, NullOrder::kFirst); - SortField field4(1, test_transform, SortDirection::kDescending, NullOrder::kFirst); - SortField field5(1, test_transform, SortDirection::kAscending, NullOrder::kLast); + SortField field4(1, bucket_transform, SortDirection::kDescending, NullOrder::kFirst); + SortField field5(1, bucket_transform, SortDirection::kAscending, NullOrder::kLast); ASSERT_EQ(field1, field1); ASSERT_NE(field1, field2); diff --git a/test/sort_order_test.cc b/test/sort_order_test.cc index 4f12e5f2f..310a013b0 100644 --- a/test/sort_order_test.cc +++ b/test/sort_order_test.cc @@ -31,24 +31,12 @@ namespace iceberg { -namespace { -class TestTransformFunction : public TransformFunction { - public: - TestTransformFunction() : TransformFunction(TransformType::kUnknown) {} - expected Transform(const ArrowArray& input) override { - return unexpected( - Error{.kind = ErrorKind::kNotSupported, .message = "test transform function"}); - } -}; - -} // namespace - TEST(SortOrderTest, Basics) { { SchemaField field1(5, "ts", std::make_shared(), true); SchemaField field2(7, "bar", std::make_shared(), true); - auto identity_transform = std::make_shared(); + auto identity_transform = Transform::Identity(); SortField st_field1(5, identity_transform, SortDirection::kAscending, NullOrder::kFirst); SortField st_field2(7, identity_transform, SortDirection::kDescending, @@ -73,13 +61,13 @@ TEST(SortOrderTest, Basics) { TEST(SortOrderTest, Equality) { SchemaField field1(5, "ts", std::make_shared(), true); SchemaField field2(7, "bar", std::make_shared(), true); - auto test_transform = std::make_shared(); - auto identity_transform = std::make_shared(); + auto bucket_transform = Transform::Bucket(8); + auto identity_transform = Transform::Identity(); SortField st_field1(5, identity_transform, SortDirection::kAscending, NullOrder::kFirst); SortField st_field2(7, identity_transform, SortDirection::kDescending, NullOrder::kFirst); - SortField st_field3(7, test_transform, SortDirection::kAscending, NullOrder::kFirst); + SortField st_field3(7, bucket_transform, SortDirection::kAscending, NullOrder::kFirst); SortOrder sort_order1(100, {st_field1, st_field2}); SortOrder sort_order2(100, {st_field2, st_field3}); SortOrder sort_order3(100, {st_field1, st_field3}); diff --git a/test/transform_test.cc b/test/transform_test.cc index fc0086b7d..9d3f36d45 100644 --- a/test/transform_test.cc +++ b/test/transform_test.cc @@ -25,30 +25,95 @@ #include #include +#include "iceberg/type.h" #include "iceberg/util/formatter.h" namespace iceberg { -TEST(TransformTest, TransformFunction) { - class TestTransformFunction : public TransformFunction { - public: - TestTransformFunction() : TransformFunction(TransformType::kUnknown) {} - expected Transform(const ArrowArray& input) override { - return unexpected( - Error{.kind = ErrorKind::kNotSupported, .message = "test transform function"}); - } - }; +TEST(TransformTest, Transform) { + auto transform = Transform::Identity(); + EXPECT_EQ(TransformType::kIdentity, transform->transform_type()); + EXPECT_EQ("identity", transform->ToString()); + EXPECT_EQ("identity", std::format("{}", *transform)); - TestTransformFunction transform; - EXPECT_EQ(TransformType::kUnknown, transform.transform_type()); - EXPECT_EQ("unknown", transform.ToString()); - EXPECT_EQ("unknown", std::format("{}", transform)); + auto source_type = std::make_shared(); + auto identity_transform = transform->Bind(source_type); + ASSERT_TRUE(identity_transform); ArrowArray arrow_array; - auto result = transform.Transform(arrow_array); + auto result = identity_transform.value()->Transform(arrow_array); ASSERT_FALSE(result); - EXPECT_EQ(ErrorKind::kNotSupported, result.error().kind); - EXPECT_EQ("test transform function", result.error().message); + EXPECT_EQ(ErrorKind::kNotImplemented, result.error().kind); + EXPECT_EQ("IdentityTransform::Transform", result.error().message); +} + +TEST(TransformFunctionTest, CreateBucketTransform) { + constexpr int32_t bucket_count = 8; + auto transform = Transform::Bucket(bucket_count); + EXPECT_EQ("bucket[8]", transform->ToString()); + EXPECT_EQ("bucket[8]", std::format("{}", *transform)); + + const auto transformPtr = transform->Bind(std::make_shared()); + ASSERT_TRUE(transformPtr); + EXPECT_EQ(transformPtr.value()->transform_type(), TransformType::kBucket); +} + +TEST(TransformFunctionTest, CreateTruncateTransform) { + constexpr int32_t width = 16; + auto transform = Transform::Truncate(width); + EXPECT_EQ("truncate[16]", transform->ToString()); + EXPECT_EQ("truncate[16]", std::format("{}", *transform)); + + auto transformPtr = transform->Bind(std::make_shared()); + EXPECT_EQ(transformPtr.value()->transform_type(), TransformType::kTruncate); +} +TEST(TransformFromStringTest, PositiveCases) { + struct Case { + std::string str; + TransformType type; + std::optional param; + }; + + const std::vector cases = { + {.str = "identity", .type = TransformType::kIdentity, .param = std::nullopt}, + {.str = "year", .type = TransformType::kYear, .param = std::nullopt}, + {.str = "month", .type = TransformType::kMonth, .param = std::nullopt}, + {.str = "day", .type = TransformType::kDay, .param = std::nullopt}, + {.str = "hour", .type = TransformType::kHour, .param = std::nullopt}, + {.str = "void", .type = TransformType::kVoid, .param = std::nullopt}, + {.str = "bucket[16]", .type = TransformType::kBucket, .param = 16}, + {.str = "truncate[32]", .type = TransformType::kTruncate, .param = 32}, + }; + for (const auto& c : cases) { + auto result = TransformFromString(c.str); + ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; + + const auto& transform = result.value(); + EXPECT_EQ(transform->transform_type(), c.type); + if (c.param.has_value()) { + EXPECT_EQ(transform->ToString(), + std::format("{}[{}]", TransformTypeToString(c.type), *c.param)); + } else { + EXPECT_EQ(transform->ToString(), TransformTypeToString(c.type)); + } + } +} + +TEST(TransformFromStringTest, NegativeCases) { + constexpr std::array invalid_cases = { + "bucket", // missing param + "bucket[]", // empty param + "bucket[abc]", // invalid number + "unknown", // unsupported transform + "bucket[16", // missing closing bracket + "truncate[1]extra" // extra characters + }; + + for (const auto& str : invalid_cases) { + auto result = TransformFromString(str); + EXPECT_FALSE(result.has_value()) << "Unexpected success for: " << str; + EXPECT_EQ(result.error().kind, ErrorKind::kInvalidArgument); + } } } // namespace iceberg