diff --git a/src/iceberg/transform.cc b/src/iceberg/transform.cc index dc6f3dc88..7898fc638 100644 --- a/src/iceberg/transform.cc +++ b/src/iceberg/transform.cc @@ -115,11 +115,11 @@ Result> Transform::Bind( switch (transform_type_) { case TransformType::kIdentity: - return std::make_unique(source_type); + return IdentityTransform::Make(source_type); case TransformType::kBucket: { if (auto param = std::get_if(¶m_)) { - return std::make_unique(source_type, *param); + return BucketTransform::Make(source_type, *param); } return InvalidArgument("Bucket requires int32 param, none found in transform '{}'", type_str); @@ -127,22 +127,22 @@ Result> Transform::Bind( case TransformType::kTruncate: { if (auto param = std::get_if(¶m_)) { - return std::make_unique(source_type, *param); + return TruncateTransform::Make(source_type, *param); } return InvalidArgument( "Truncate requires int32 param, none found in transform '{}'", type_str); } case TransformType::kYear: - return std::make_unique(source_type); + return YearTransform::Make(source_type); case TransformType::kMonth: - return std::make_unique(source_type); + return MonthTransform::Make(source_type); case TransformType::kDay: - return std::make_unique(source_type); + return DayTransform::Make(source_type); case TransformType::kHour: - return std::make_unique(source_type); + return HourTransform::Make(source_type); case TransformType::kVoid: - return std::make_unique(source_type); + return VoidTransform::Make(source_type); default: return NotSupported("Unsupported transform type: '{}'", type_str); diff --git a/src/iceberg/transform_function.cc b/src/iceberg/transform_function.cc index 6aa49bff6..9ddf6e9f7 100644 --- a/src/iceberg/transform_function.cc +++ b/src/iceberg/transform_function.cc @@ -31,12 +31,16 @@ Result IdentityTransform::Transform(const ArrowArray& input) { } Result> IdentityTransform::ResultType() const { - auto src_type = source_type(); - if (!src_type || !src_type->is_primitive()) { + return source_type(); +} + +Result> IdentityTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type || !source_type->is_primitive()) { return NotSupported("{} is not a valid input type for identity transform", - src_type ? src_type->ToString() : "null"); + source_type ? source_type->ToString() : "null"); } - return src_type; + return std::make_unique(source_type); } BucketTransform::BucketTransform(std::shared_ptr const& source_type, @@ -48,7 +52,35 @@ Result BucketTransform::Transform(const ArrowArray& input) { } Result> BucketTransform::ResultType() const { - return NotImplemented("BucketTransform::result_type"); + return iceberg::int32(); +} + +Result> BucketTransform::Make( + std::shared_ptr const& source_type, int32_t num_buckets) { + if (!source_type) { + return NotSupported("null is not a valid input type for bucket transform"); + } + switch (source_type->type_id()) { + case TypeId::kInt: + case TypeId::kLong: + case TypeId::kDecimal: + case TypeId::kDate: + case TypeId::kTime: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + case TypeId::kString: + case TypeId::kUuid: + case TypeId::kFixed: + case TypeId::kBinary: + break; + default: + return NotSupported("{} is not a valid input type for bucket transform", + source_type->ToString()); + } + if (num_buckets <= 0) { + return InvalidArgument("Number of buckets must be positive, got {}", num_buckets); + } + return std::make_unique(source_type, num_buckets); } TruncateTransform::TruncateTransform(std::shared_ptr const& source_type, @@ -60,7 +92,29 @@ Result TruncateTransform::Transform(const ArrowArray& input) { } Result> TruncateTransform::ResultType() const { - return NotImplemented("TruncateTransform::result_type"); + return source_type(); +} + +Result> TruncateTransform::Make( + std::shared_ptr const& source_type, int32_t width) { + if (!source_type) { + return NotSupported("null is not a valid input type for truncate transform"); + } + switch (source_type->type_id()) { + case TypeId::kInt: + case TypeId::kLong: + case TypeId::kDecimal: + case TypeId::kString: + case TypeId::kBinary: + break; + default: + return NotSupported("{} is not a valid input type for truncate transform", + source_type->ToString()); + } + if (width <= 0) { + return InvalidArgument("Width must be positive, got {}", width); + } + return std::make_unique(source_type, width); } YearTransform::YearTransform(std::shared_ptr const& source_type) @@ -71,7 +125,24 @@ Result YearTransform::Transform(const ArrowArray& input) { } Result> YearTransform::ResultType() const { - return NotImplemented("YearTransform::result_type"); + return iceberg::int32(); +} + +Result> YearTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { + return NotSupported("null is not a valid input type for year transform"); + } + switch (source_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for year transform", + source_type->ToString()); + } + return std::make_unique(source_type); } MonthTransform::MonthTransform(std::shared_ptr const& source_type) @@ -82,7 +153,24 @@ Result MonthTransform::Transform(const ArrowArray& input) { } Result> MonthTransform::ResultType() const { - return NotImplemented("MonthTransform::result_type"); + return iceberg::int32(); +} + +Result> MonthTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { + return NotSupported("null is not a valid input type for month transform"); + } + switch (source_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for month transform", + source_type->ToString()); + } + return std::make_unique(source_type); } DayTransform::DayTransform(std::shared_ptr const& source_type) @@ -92,8 +180,23 @@ Result DayTransform::Transform(const ArrowArray& input) { return NotImplemented("DayTransform::Transform"); } -Result> DayTransform::ResultType() const { - return NotImplemented("DayTransform::result_type"); +Result> DayTransform::ResultType() const { return iceberg::date(); } + +Result> DayTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { + return NotSupported("null is not a valid input type for day transform"); + } + switch (source_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for day transform", + source_type->ToString()); + } + return std::make_unique(source_type); } HourTransform::HourTransform(std::shared_ptr const& source_type) @@ -104,7 +207,23 @@ Result HourTransform::Transform(const ArrowArray& input) { } Result> HourTransform::ResultType() const { - return NotImplemented("HourTransform::result_type"); + return iceberg::int32(); +} + +Result> HourTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { + return NotSupported("null is not a valid input type for hour transform"); + } + switch (source_type->type_id()) { + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for hour transform", + source_type->ToString()); + } + return std::make_unique(source_type); } VoidTransform::VoidTransform(std::shared_ptr const& source_type) @@ -114,8 +233,14 @@ Result VoidTransform::Transform(const ArrowArray& input) { return NotImplemented("VoidTransform::Transform"); } -Result> VoidTransform::ResultType() const { - return NotImplemented("VoidTransform::result_type"); +Result> VoidTransform::ResultType() const { return source_type(); } + +Result> VoidTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { + return NotSupported("null is not a valid input type for void transform"); + } + return std::make_unique(source_type); } } // namespace iceberg diff --git a/src/iceberg/transform_function.h b/src/iceberg/transform_function.h index eb844324c..7fffd61f0 100644 --- a/src/iceberg/transform_function.h +++ b/src/iceberg/transform_function.h @@ -35,6 +35,12 @@ class IdentityTransform : public TransformFunction { /// \brief Returns the same type as the source type if it is valid. Result> ResultType() const override; + + /// \brief Create an IdentityTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the IdentityTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Bucket transform that hashes input values into N buckets. @@ -50,6 +56,13 @@ class BucketTransform : public TransformFunction { /// \brief Returns INT32 as the output type. Result> ResultType() const override; + /// \brief Create a BucketTransform. + /// \param source_type Type of the input data. + /// \param num_buckets Number of buckets to hash into. + /// \return A Result containing the BucketTransform or an error. + static Result> Make( + std::shared_ptr const& source_type, int32_t num_buckets); + private: int32_t num_buckets_; }; @@ -67,6 +80,13 @@ class TruncateTransform : public TransformFunction { /// \brief Returns the same type as source_type. Result> ResultType() const override; + /// \brief Create a TruncateTransform. + /// \param source_type Type of the input data. + /// \param width The width to truncate to. + /// \return A Result containing the TruncateTransform or an error. + static Result> Make( + std::shared_ptr const& source_type, int32_t width); + private: int32_t width_; }; @@ -82,6 +102,12 @@ class YearTransform : public TransformFunction { /// \brief Returns INT32 as the output type. Result> ResultType() const override; + + /// \brief Create a YearTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the YearTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Month transform that extracts the month component from timestamp inputs. @@ -95,6 +121,12 @@ class MonthTransform : public TransformFunction { /// \brief Returns INT32 as the output type. Result> ResultType() const override; + + /// \brief Create a MonthTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the MonthTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Day transform that extracts the day of the month from timestamp inputs. @@ -108,6 +140,12 @@ class DayTransform : public TransformFunction { /// \brief Returns INT32 as the output type. Result> ResultType() const override; + + /// \brief Create a DayTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the DayTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Hour transform that extracts the hour component from timestamp inputs. @@ -121,6 +159,12 @@ class HourTransform : public TransformFunction { /// \brief Returns INT32 as the output type. Result> ResultType() const override; + + /// \brief Create a HourTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the HourTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Void transform that discards the input and always returns null. @@ -134,6 +178,12 @@ class VoidTransform : public TransformFunction { /// \brief Returns null type or a sentinel type indicating void. Result> ResultType() const override; + + /// \brief Create a VoidTransform. + /// \param source_type Input type (ignored). + /// \return A Result containing the VoidTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; } // namespace iceberg diff --git a/test/transform_test.cc b/test/transform_test.cc index 2f06eb7a9..33149d14d 100644 --- a/test/transform_test.cc +++ b/test/transform_test.cc @@ -117,4 +117,80 @@ TEST(TransformFromStringTest, NegativeCases) { } } +TEST(TransformResultTypeTest, PositiveCases) { + struct Case { + std::string str; + std::shared_ptr source_type; + std::shared_ptr expected_result_type; + }; + + const std::vector cases = { + {.str = "identity", + .source_type = iceberg::string(), + .expected_result_type = iceberg::string()}, + {.str = "year", + .source_type = iceberg::timestamp(), + .expected_result_type = iceberg::int32()}, + {.str = "month", + .source_type = iceberg::timestamp(), + .expected_result_type = iceberg::int32()}, + {.str = "day", + .source_type = iceberg::timestamp(), + .expected_result_type = iceberg::date()}, + {.str = "hour", + .source_type = iceberg::timestamp(), + .expected_result_type = iceberg::int32()}, + {.str = "void", + .source_type = iceberg::string(), + .expected_result_type = iceberg::string()}, + {.str = "bucket[16]", + .source_type = iceberg::string(), + .expected_result_type = iceberg::int32()}, + {.str = "truncate[32]", + .source_type = iceberg::string(), + .expected_result_type = iceberg::string()}, + }; + + for (const auto& c : cases) { + auto result = TransformFromString(c.str); + ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; + + const auto& transform = result.value(); + const auto transformPtr = transform->Bind(c.source_type); + ASSERT_TRUE(transformPtr.has_value()) << "Failed to bind: " << c.str; + + auto result_type = transformPtr.value()->ResultType(); + ASSERT_TRUE(result_type.has_value()) << "Failed to get result type for: " << c.str; + EXPECT_EQ(result_type.value()->type_id(), c.expected_result_type->type_id()) + << "Unexpected result type for: " << c.str; + } +} + +TEST(TransformResultTypeTest, NegativeCases) { + struct Case { + std::string str; + std::shared_ptr source_type; + }; + + const std::vector cases = { + {.str = "identity", .source_type = nullptr}, + {.str = "year", .source_type = iceberg::string()}, + {.str = "month", .source_type = iceberg::string()}, + {.str = "day", .source_type = iceberg::string()}, + {.str = "hour", .source_type = iceberg::string()}, + {.str = "void", .source_type = nullptr}, + {.str = "bucket[16]", .source_type = iceberg::float32()}, + {.str = "truncate[32]", .source_type = iceberg::float64()}}; + + for (const auto& c : cases) { + auto result = TransformFromString(c.str); + ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; + + const auto& transform = result.value(); + auto transformPtr = transform->Bind(c.source_type); + + ASSERT_THAT(transformPtr, IsError(ErrorKind::kNotSupported)); + } +} + } // namespace iceberg