From 9e0fac3d6269562fa7fff9864e83d4430cc97ed4 Mon Sep 17 00:00:00 2001 From: Junwang Zhao Date: Sat, 28 Jun 2025 21:08:37 +0800 Subject: [PATCH 1/4] feat: implement transform ResultType --- src/iceberg/transform_function.cc | 99 ++++++++++++++++++++++++++++--- test/transform_test.cc | 77 ++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 7 deletions(-) diff --git a/src/iceberg/transform_function.cc b/src/iceberg/transform_function.cc index 6aa49bff6..32b1a1f85 100644 --- a/src/iceberg/transform_function.cc +++ b/src/iceberg/transform_function.cc @@ -48,7 +48,27 @@ Result BucketTransform::Transform(const ArrowArray& input) { } Result> BucketTransform::ResultType() const { - return NotImplemented("BucketTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for bucket transform"); + } + switch (src_type->type_id()) { + case TypeId::kInt: + case TypeId::kLong: + case TypeId::kDecimal: + case TypeId::kDate: + case TypeId::kTime: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + case TypeId::kString: + case TypeId::kUuid: + case TypeId::kFixed: + case TypeId::kBinary: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for bucket transform", + src_type->ToString()); + } } TruncateTransform::TruncateTransform(std::shared_ptr const& source_type, @@ -60,7 +80,21 @@ Result TruncateTransform::Transform(const ArrowArray& input) { } Result> TruncateTransform::ResultType() const { - return NotImplemented("TruncateTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for truncate transform"); + } + switch (src_type->type_id()) { + case TypeId::kInt: + case TypeId::kLong: + case TypeId::kDecimal: + case TypeId::kString: + case TypeId::kBinary: + return src_type; + default: + return NotSupported("{} is not a valid input type for truncate transform", + 
src_type->ToString()); + } } YearTransform::YearTransform(std::shared_ptr const& source_type) @@ -71,7 +105,19 @@ Result YearTransform::Transform(const ArrowArray& input) { } Result> YearTransform::ResultType() const { - return NotImplemented("YearTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for year transform"); + } + switch (src_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for year transform", + src_type->ToString()); + } } MonthTransform::MonthTransform(std::shared_ptr const& source_type) @@ -82,7 +128,19 @@ Result MonthTransform::Transform(const ArrowArray& input) { } Result> MonthTransform::ResultType() const { - return NotImplemented("MonthTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for month transform"); + } + switch (src_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for month transform", + src_type->ToString()); + } } DayTransform::DayTransform(std::shared_ptr const& source_type) @@ -93,7 +151,19 @@ Result DayTransform::Transform(const ArrowArray& input) { } Result> DayTransform::ResultType() const { - return NotImplemented("DayTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for day transform"); + } + switch (src_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for day transform", + src_type->ToString()); + } } HourTransform::HourTransform(std::shared_ptr const& source_type) @@ -104,7 
+174,18 @@ Result HourTransform::Transform(const ArrowArray& input) { } Result> HourTransform::ResultType() const { - return NotImplemented("HourTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for hour transform"); + } + switch (src_type->type_id()) { + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return std::make_shared(); + default: + return NotSupported("{} is not a valid input type for hour transform", + src_type->ToString()); + } } VoidTransform::VoidTransform(std::shared_ptr const& source_type) @@ -115,7 +196,11 @@ Result VoidTransform::Transform(const ArrowArray& input) { } Result> VoidTransform::ResultType() const { - return NotImplemented("VoidTransform::result_type"); + auto src_type = source_type(); + if (!src_type) { + return NotSupported("null is not a valid input type for void transform"); + } + return src_type; } } // namespace iceberg diff --git a/test/transform_test.cc b/test/transform_test.cc index 2f06eb7a9..a1de10b0e 100644 --- a/test/transform_test.cc +++ b/test/transform_test.cc @@ -117,4 +117,81 @@ TEST(TransformFromStringTest, NegativeCases) { } } +TEST(TransformResultTypeTest, PositiveCases) { + struct Case { + std::string str; + std::shared_ptr source_type; + std::shared_ptr expected_result_type; + }; + + const std::vector cases = { + {.str = "identity", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "year", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "month", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "day", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "hour", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "void", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = 
"bucket[16]", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + {.str = "truncate[32]", + .source_type = std::make_shared(), + .expected_result_type = std::make_shared()}, + }; + + for (const auto& c : cases) { + auto result = TransformFromString(c.str); + ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; + + const auto& transform = result.value(); + const auto transformPtr = transform->Bind(c.source_type); + ASSERT_TRUE(transformPtr.has_value()) << "Failed to bind: " << c.str; + + auto result_type = transformPtr.value()->ResultType(); + ASSERT_TRUE(result_type.has_value()) << "Failed to get result type for: " << c.str; + EXPECT_EQ(result_type.value()->type_id(), c.expected_result_type->type_id()) + << "Unexpected result type for: " << c.str; + } +} + +TEST(TransformResultTypeTest, NegativeCases) { + struct Case { + std::string str; + std::shared_ptr source_type; + }; + + const std::vector cases = { + {.str = "identity", .source_type = nullptr}, + {.str = "year", .source_type = std::make_shared()}, + {.str = "month", .source_type = std::make_shared()}, + {.str = "day", .source_type = std::make_shared()}, + {.str = "hour", .source_type = std::make_shared()}, + {.str = "void", .source_type = nullptr}, + {.str = "bucket[16]", .source_type = std::make_shared()}, + {.str = "truncate[32]", .source_type = std::make_shared()}}; + + for (const auto& c : cases) { + auto result = TransformFromString(c.str); + ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; + + const auto& transform = result.value(); + auto transformPtr = transform->Bind(c.source_type); + + auto result_type = transformPtr.value()->ResultType(); + ASSERT_THAT(result_type, IsError(ErrorKind::kNotSupported)); + } +} + } // namespace iceberg From f820aa277ac6b98506861f1c3d091735d51a3d59 Mon Sep 17 00:00:00 2001 From: Junwang Zhao Date: Mon, 7 Jul 2025 21:30:34 +0800 Subject: [PATCH 2/4] use type factory method --- src/iceberg/transform_function.cc 
| 10 +++---- test/transform_test.cc | 44 +++++++++++++++---------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/iceberg/transform_function.cc b/src/iceberg/transform_function.cc index 32b1a1f85..55b79dab6 100644 --- a/src/iceberg/transform_function.cc +++ b/src/iceberg/transform_function.cc @@ -64,7 +64,7 @@ Result> BucketTransform::ResultType() const { case TypeId::kUuid: case TypeId::kFixed: case TypeId::kBinary: - return std::make_shared(); + return iceberg::int32(); default: return NotSupported("{} is not a valid input type for bucket transform", src_type->ToString()); @@ -113,7 +113,7 @@ Result> YearTransform::ResultType() const { case TypeId::kDate: case TypeId::kTimestamp: case TypeId::kTimestampTz: - return std::make_shared(); + return iceberg::int32(); default: return NotSupported("{} is not a valid input type for year transform", src_type->ToString()); @@ -136,7 +136,7 @@ Result> MonthTransform::ResultType() const { case TypeId::kDate: case TypeId::kTimestamp: case TypeId::kTimestampTz: - return std::make_shared(); + return iceberg::int32(); default: return NotSupported("{} is not a valid input type for month transform", src_type->ToString()); @@ -159,7 +159,7 @@ Result> DayTransform::ResultType() const { case TypeId::kDate: case TypeId::kTimestamp: case TypeId::kTimestampTz: - return std::make_shared(); + return iceberg::date(); default: return NotSupported("{} is not a valid input type for day transform", src_type->ToString()); @@ -181,7 +181,7 @@ Result> HourTransform::ResultType() const { switch (src_type->type_id()) { case TypeId::kTimestamp: case TypeId::kTimestampTz: - return std::make_shared(); + return iceberg::int32(); default: return NotSupported("{} is not a valid input type for hour transform", src_type->ToString()); diff --git a/test/transform_test.cc b/test/transform_test.cc index a1de10b0e..3d0aeef33 100644 --- a/test/transform_test.cc +++ b/test/transform_test.cc @@ -126,29 +126,29 @@ 
TEST(TransformResultTypeTest, PositiveCases) { const std::vector cases = { {.str = "identity", - .source_type = std::make_shared(), - .expected_result_type = std::make_shared()}, + .source_type = iceberg::string(), + .expected_result_type = iceberg::string()}, {.str = "year", - .source_type = std::make_shared(), - .expected_result_type = std::make_shared()}, + .source_type = iceberg::timestamp(), + .expected_result_type = iceberg::int32()}, {.str = "month", - .source_type = std::make_shared(), - .expected_result_type = std::make_shared()}, + .source_type = iceberg::timestamp(), + .expected_result_type = iceberg::int32()}, {.str = "day", - .source_type = std::make_shared(), - .expected_result_type = std::make_shared()}, + .source_type = iceberg::timestamp(), + .expected_result_type = iceberg::date()}, {.str = "hour", - .source_type = std::make_shared(), - .expected_result_type = std::make_shared()}, + .source_type = iceberg::timestamp(), + .expected_result_type = iceberg::int32()}, {.str = "void", - .source_type = std::make_shared(), - .expected_result_type = std::make_shared()}, + .source_type = iceberg::string(), + .expected_result_type = iceberg::string()}, {.str = "bucket[16]", - .source_type = std::make_shared(), - .expected_result_type = std::make_shared()}, + .source_type = iceberg::string(), + .expected_result_type = iceberg::int32()}, {.str = "truncate[32]", - .source_type = std::make_shared(), - .expected_result_type = std::make_shared()}, + .source_type = iceberg::string(), + .expected_result_type = iceberg::string()}, }; for (const auto& c : cases) { @@ -174,13 +174,13 @@ TEST(TransformResultTypeTest, NegativeCases) { const std::vector cases = { {.str = "identity", .source_type = nullptr}, - {.str = "year", .source_type = std::make_shared()}, - {.str = "month", .source_type = std::make_shared()}, - {.str = "day", .source_type = std::make_shared()}, - {.str = "hour", .source_type = std::make_shared()}, + {.str = "year", .source_type = iceberg::string()}, 
+ {.str = "month", .source_type = iceberg::string()}, + {.str = "day", .source_type = iceberg::string()}, + {.str = "hour", .source_type = iceberg::string()}, {.str = "void", .source_type = nullptr}, - {.str = "bucket[16]", .source_type = std::make_shared()}, - {.str = "truncate[32]", .source_type = std::make_shared()}}; + {.str = "bucket[16]", .source_type = iceberg::float32()}, + {.str = "truncate[32]", .source_type = iceberg::float64()}}; for (const auto& c : cases) { auto result = TransformFromString(c.str); From 41c92a64f6190662a7f7c3fb20884c7700ebce6e Mon Sep 17 00:00:00 2001 From: Junwang Zhao Date: Tue, 22 Jul 2025 00:15:51 +0800 Subject: [PATCH 3/4] add XYZTransform::Make --- src/iceberg/transform.cc | 16 +++--- src/iceberg/transform_function.cc | 88 ++++++++++++++++++++++--------- src/iceberg/transform_function.h | 50 ++++++++++++++++++ test/transform_test.cc | 22 +++++++- 4 files changed, 140 insertions(+), 36 deletions(-) diff --git a/src/iceberg/transform.cc b/src/iceberg/transform.cc index dc6f3dc88..7898fc638 100644 --- a/src/iceberg/transform.cc +++ b/src/iceberg/transform.cc @@ -115,11 +115,11 @@ Result> Transform::Bind( switch (transform_type_) { case TransformType::kIdentity: - return std::make_unique(source_type); + return IdentityTransform::Make(source_type); case TransformType::kBucket: { if (auto param = std::get_if(&param_)) { - return std::make_unique(source_type, *param); + return BucketTransform::Make(source_type, *param); } return InvalidArgument("Bucket requires int32 param, none found in transform '{}'", type_str); @@ -127,22 +127,22 @@ Result> Transform::Bind( case TransformType::kTruncate: { if (auto param = std::get_if(&param_)) { - return std::make_unique(source_type, *param); + return TruncateTransform::Make(source_type, *param); } return InvalidArgument( "Truncate requires int32 param, none found in transform '{}'", type_str); } case TransformType::kYear: - return std::make_unique(source_type); + return YearTransform::Make(source_type); 
case TransformType::kMonth: - return std::make_unique(source_type); + return MonthTransform::Make(source_type); case TransformType::kDay: - return std::make_unique(source_type); + return DayTransform::Make(source_type); case TransformType::kHour: - return std::make_unique(source_type); + return HourTransform::Make(source_type); case TransformType::kVoid: - return std::make_unique(source_type); + return VoidTransform::Make(source_type); default: return NotSupported("Unsupported transform type: '{}'", type_str); diff --git a/src/iceberg/transform_function.cc b/src/iceberg/transform_function.cc index 55b79dab6..47bd506b6 100644 --- a/src/iceberg/transform_function.cc +++ b/src/iceberg/transform_function.cc @@ -31,12 +31,16 @@ Result IdentityTransform::Transform(const ArrowArray& input) { } Result> IdentityTransform::ResultType() const { - auto src_type = source_type(); - if (!src_type || !src_type->is_primitive()) { + return source_type(); +} + +Result> IdentityTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type || !source_type->is_primitive()) { return NotSupported("{} is not a valid input type for identity transform", - src_type ? src_type->ToString() : "null"); + source_type ? 
source_type->ToString() : "null"); } - return src_type; + return std::make_unique(source_type); } BucketTransform::BucketTransform(std::shared_ptr const& source_type, @@ -49,9 +53,6 @@ Result BucketTransform::Transform(const ArrowArray& input) { Result> BucketTransform::ResultType() const { auto src_type = source_type(); - if (!src_type) { - return NotSupported("null is not a valid input type for bucket transform"); - } switch (src_type->type_id()) { case TypeId::kInt: case TypeId::kLong: @@ -71,6 +72,14 @@ Result> BucketTransform::ResultType() const { } } +Result> BucketTransform::Make( + std::shared_ptr const& source_type, int32_t num_buckets) { + if (!source_type) { + return NotSupported("null is not a valid input type for bucket transform"); + } + return std::make_unique(source_type, num_buckets); +} + TruncateTransform::TruncateTransform(std::shared_ptr const& source_type, int32_t width) : TransformFunction(TransformType::kTruncate, source_type), width_(width) {} @@ -81,9 +90,6 @@ Result TruncateTransform::Transform(const ArrowArray& input) { Result> TruncateTransform::ResultType() const { auto src_type = source_type(); - if (!src_type) { - return NotSupported("null is not a valid input type for truncate transform"); - } switch (src_type->type_id()) { case TypeId::kInt: case TypeId::kLong: @@ -97,6 +103,14 @@ Result> TruncateTransform::ResultType() const { } } +Result> TruncateTransform::Make( + std::shared_ptr const& source_type, int32_t width) { + if (!source_type) { + return NotSupported("null is not a valid input type for truncate transform"); + } + return std::make_unique(source_type, width); +} + YearTransform::YearTransform(std::shared_ptr const& source_type) : TransformFunction(TransformType::kTruncate, source_type) {} @@ -106,9 +120,6 @@ Result YearTransform::Transform(const ArrowArray& input) { Result> YearTransform::ResultType() const { auto src_type = source_type(); - if (!src_type) { - return NotSupported("null is not a valid input type for year 
transform"); - } switch (src_type->type_id()) { case TypeId::kDate: case TypeId::kTimestamp: @@ -120,6 +131,14 @@ Result> YearTransform::ResultType() const { } } +Result> YearTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { + return NotSupported("null is not a valid input type for year transform"); + } + return std::make_unique(source_type); +} + MonthTransform::MonthTransform(std::shared_ptr const& source_type) : TransformFunction(TransformType::kMonth, source_type) {} @@ -129,9 +148,6 @@ Result MonthTransform::Transform(const ArrowArray& input) { Result> MonthTransform::ResultType() const { auto src_type = source_type(); - if (!src_type) { - return NotSupported("null is not a valid input type for month transform"); - } switch (src_type->type_id()) { case TypeId::kDate: case TypeId::kTimestamp: @@ -143,6 +159,14 @@ Result> MonthTransform::ResultType() const { } } +Result> MonthTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { + return NotSupported("null is not a valid input type for month transform"); + } + return std::make_unique(source_type); +} + DayTransform::DayTransform(std::shared_ptr const& source_type) : TransformFunction(TransformType::kDay, source_type) {} @@ -152,9 +176,6 @@ Result DayTransform::Transform(const ArrowArray& input) { Result> DayTransform::ResultType() const { auto src_type = source_type(); - if (!src_type) { - return NotSupported("null is not a valid input type for day transform"); - } switch (src_type->type_id()) { case TypeId::kDate: case TypeId::kTimestamp: @@ -166,6 +187,14 @@ Result> DayTransform::ResultType() const { } } +Result> DayTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { + return NotSupported("null is not a valid input type for day transform"); + } + return std::make_unique(source_type); +} + HourTransform::HourTransform(std::shared_ptr const& source_type) : TransformFunction(TransformType::kHour, source_type) {} @@ -175,9 +204,6 @@ 
Result HourTransform::Transform(const ArrowArray& input) { Result> HourTransform::ResultType() const { auto src_type = source_type(); - if (!src_type) { - return NotSupported("null is not a valid input type for hour transform"); - } switch (src_type->type_id()) { case TypeId::kTimestamp: case TypeId::kTimestampTz: @@ -188,6 +214,14 @@ Result> HourTransform::ResultType() const { } } +Result> HourTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { + return NotSupported("null is not a valid input type for hour transform"); + } + return std::make_unique(source_type); +} + VoidTransform::VoidTransform(std::shared_ptr const& source_type) : TransformFunction(TransformType::kVoid, source_type) {} @@ -195,12 +229,14 @@ Result VoidTransform::Transform(const ArrowArray& input) { return NotImplemented("VoidTransform::Transform"); } -Result> VoidTransform::ResultType() const { - auto src_type = source_type(); - if (!src_type) { +Result> VoidTransform::ResultType() const { return source_type(); } + +Result> VoidTransform::Make( + std::shared_ptr const& source_type) { + if (!source_type) { return NotSupported("null is not a valid input type for void transform"); } - return src_type; + return std::make_unique(source_type); } } // namespace iceberg diff --git a/src/iceberg/transform_function.h b/src/iceberg/transform_function.h index eb844324c..7fffd61f0 100644 --- a/src/iceberg/transform_function.h +++ b/src/iceberg/transform_function.h @@ -35,6 +35,12 @@ class IdentityTransform : public TransformFunction { /// \brief Returns the same type as the source type if it is valid. Result> ResultType() const override; + + /// \brief Create an IdentityTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the IdentityTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Bucket transform that hashes input values into N buckets. 
@@ -50,6 +56,13 @@ class BucketTransform : public TransformFunction { /// \brief Returns INT32 as the output type. Result> ResultType() const override; + /// \brief Create a BucketTransform. + /// \param source_type Type of the input data. + /// \param num_buckets Number of buckets to hash into. + /// \return A Result containing the BucketTransform or an error. + static Result> Make( + std::shared_ptr const& source_type, int32_t num_buckets); + private: int32_t num_buckets_; }; @@ -67,6 +80,13 @@ class TruncateTransform : public TransformFunction { /// \brief Returns the same type as source_type. Result> ResultType() const override; + /// \brief Create a TruncateTransform. + /// \param source_type Type of the input data. + /// \param width The width to truncate to. + /// \return A Result containing the TruncateTransform or an error. + static Result> Make( + std::shared_ptr const& source_type, int32_t width); + private: int32_t width_; }; @@ -82,6 +102,12 @@ class YearTransform : public TransformFunction { /// \brief Returns INT32 as the output type. Result> ResultType() const override; + + /// \brief Create a YearTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the YearTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Month transform that extracts the month component from timestamp inputs. @@ -95,6 +121,12 @@ class MonthTransform : public TransformFunction { /// \brief Returns INT32 as the output type. Result> ResultType() const override; + + /// \brief Create a MonthTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the MonthTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Day transform that extracts the day of the month from timestamp inputs. @@ -108,6 +140,12 @@ class DayTransform : public TransformFunction { /// \brief Returns INT32 as the output type. 
Result> ResultType() const override; + + /// \brief Create a DayTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the DayTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Hour transform that extracts the hour component from timestamp inputs. @@ -121,6 +159,12 @@ class HourTransform : public TransformFunction { /// \brief Returns INT32 as the output type. Result> ResultType() const override; + + /// \brief Create a HourTransform. + /// \param source_type Type of the input data. + /// \return A Result containing the HourTransform or an error. + static Result> Make( + std::shared_ptr const& source_type); }; /// \brief Void transform that discards the input and always returns null. @@ -134,6 +178,12 @@ class VoidTransform : public TransformFunction { /// \brief Returns null type or a sentinel type indicating void. Result> ResultType() const override; + + /// \brief Create a VoidTransform. + /// \param source_type Input type (ignored). + /// \return A Result containing the VoidTransform or an error. 
+ static Result> Make( + std::shared_ptr const& source_type); }; } // namespace iceberg diff --git a/test/transform_test.cc b/test/transform_test.cc index 3d0aeef33..0f35b4543 100644 --- a/test/transform_test.cc +++ b/test/transform_test.cc @@ -173,15 +173,22 @@ TEST(TransformResultTypeTest, NegativeCases) { }; const std::vector cases = { - {.str = "identity", .source_type = nullptr}, {.str = "year", .source_type = iceberg::string()}, {.str = "month", .source_type = iceberg::string()}, {.str = "day", .source_type = iceberg::string()}, {.str = "hour", .source_type = iceberg::string()}, - {.str = "void", .source_type = nullptr}, {.str = "bucket[16]", .source_type = iceberg::float32()}, {.str = "truncate[32]", .source_type = iceberg::float64()}}; + const std::vector null_cases = {{.str = "identity", .source_type = nullptr}, + {.str = "year", .source_type = nullptr}, + {.str = "month", .source_type = nullptr}, + {.str = "day", .source_type = nullptr}, + {.str = "hour", .source_type = nullptr}, + {.str = "void", .source_type = nullptr}, + {.str = "bucket[16]", .source_type = nullptr}, + {.str = "truncate[32]", .source_type = nullptr}}; + for (const auto& c : cases) { auto result = TransformFromString(c.str); ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; @@ -192,6 +199,17 @@ TEST(TransformResultTypeTest, NegativeCases) { auto result_type = transformPtr.value()->ResultType(); ASSERT_THAT(result_type, IsError(ErrorKind::kNotSupported)); } + + for (const auto& c : null_cases) { + auto result = TransformFromString(c.str); + ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; + + const auto& transform = result.value(); + auto transformPtr = transform->Bind(c.source_type); + + ASSERT_THAT(transformPtr, IsError(ErrorKind::kNotSupported)); + EXPECT_THAT(transformPtr, HasErrorMessage("null is not a valid")); + } } } // namespace iceberg From bf949711996c66fbe0fe7e3473c66c81dcb68b32 Mon Sep 17 00:00:00 2001 From: Junwang Zhao Date: Tue, 22 Jul 2025 
21:47:08 +0800 Subject: [PATCH 4/4] validate source_type and other params in Make function --- src/iceberg/transform_function.cc | 126 +++++++++++++++--------------- test/transform_test.cc | 23 +----- 2 files changed, 67 insertions(+), 82 deletions(-) diff --git a/src/iceberg/transform_function.cc b/src/iceberg/transform_function.cc index 47bd506b6..9ddf6e9f7 100644 --- a/src/iceberg/transform_function.cc +++ b/src/iceberg/transform_function.cc @@ -52,8 +52,15 @@ Result BucketTransform::Transform(const ArrowArray& input) { } Result> BucketTransform::ResultType() const { - auto src_type = source_type(); - switch (src_type->type_id()) { + return iceberg::int32(); +} + +Result> BucketTransform::Make( + std::shared_ptr const& source_type, int32_t num_buckets) { + if (!source_type) { + return NotSupported("null is not a valid input type for bucket transform"); + } + switch (source_type->type_id()) { case TypeId::kInt: case TypeId::kLong: case TypeId::kDecimal: @@ -65,17 +72,13 @@ Result> BucketTransform::ResultType() const { case TypeId::kUuid: case TypeId::kFixed: case TypeId::kBinary: - return iceberg::int32(); + break; default: return NotSupported("{} is not a valid input type for bucket transform", - src_type->ToString()); + source_type->ToString()); } -} - -Result> BucketTransform::Make( - std::shared_ptr const& source_type, int32_t num_buckets) { - if (!source_type) { - return NotSupported("null is not a valid input type for bucket transform"); + if (num_buckets <= 0) { + return InvalidArgument("Number of buckets must be positive, got {}", num_buckets); } return std::make_unique(source_type, num_buckets); } @@ -89,24 +92,27 @@ Result TruncateTransform::Transform(const ArrowArray& input) { } Result> TruncateTransform::ResultType() const { - auto src_type = source_type(); - switch (src_type->type_id()) { + return source_type(); +} + +Result> TruncateTransform::Make( + std::shared_ptr const& source_type, int32_t width) { + if (!source_type) { + return 
NotSupported("null is not a valid input type for truncate transform"); + } + switch (source_type->type_id()) { case TypeId::kInt: case TypeId::kLong: case TypeId::kDecimal: case TypeId::kString: case TypeId::kBinary: - return src_type; + break; default: return NotSupported("{} is not a valid input type for truncate transform", - src_type->ToString()); + source_type->ToString()); } -} - -Result> TruncateTransform::Make( - std::shared_ptr const& source_type, int32_t width) { - if (!source_type) { - return NotSupported("null is not a valid input type for truncate transform"); + if (width <= 0) { + return InvalidArgument("Width must be positive, got {}", width); } return std::make_unique(source_type, width); } @@ -119,16 +125,7 @@ Result YearTransform::Transform(const ArrowArray& input) { } Result> YearTransform::ResultType() const { - auto src_type = source_type(); - switch (src_type->type_id()) { - case TypeId::kDate: - case TypeId::kTimestamp: - case TypeId::kTimestampTz: - return iceberg::int32(); - default: - return NotSupported("{} is not a valid input type for year transform", - src_type->ToString()); - } + return iceberg::int32(); } Result> YearTransform::Make( @@ -136,6 +133,15 @@ Result> YearTransform::Make( if (!source_type) { return NotSupported("null is not a valid input type for year transform"); } + switch (source_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for year transform", + source_type->ToString()); + } return std::make_unique(source_type); } @@ -147,16 +153,7 @@ Result MonthTransform::Transform(const ArrowArray& input) { } Result> MonthTransform::ResultType() const { - auto src_type = source_type(); - switch (src_type->type_id()) { - case TypeId::kDate: - case TypeId::kTimestamp: - case TypeId::kTimestampTz: - return iceberg::int32(); - default: - return NotSupported("{} is not a valid input type for month transform", - 
src_type->ToString()); - } + return iceberg::int32(); } Result> MonthTransform::Make( @@ -164,6 +161,15 @@ Result> MonthTransform::Make( if (!source_type) { return NotSupported("null is not a valid input type for month transform"); } + switch (source_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for month transform", + source_type->ToString()); + } return std::make_unique(source_type); } @@ -174,24 +180,22 @@ Result DayTransform::Transform(const ArrowArray& input) { return NotImplemented("DayTransform::Transform"); } -Result> DayTransform::ResultType() const { - auto src_type = source_type(); - switch (src_type->type_id()) { - case TypeId::kDate: - case TypeId::kTimestamp: - case TypeId::kTimestampTz: - return iceberg::date(); - default: - return NotSupported("{} is not a valid input type for day transform", - src_type->ToString()); - } -} +Result> DayTransform::ResultType() const { return iceberg::date(); } Result> DayTransform::Make( std::shared_ptr const& source_type) { if (!source_type) { return NotSupported("null is not a valid input type for day transform"); } + switch (source_type->type_id()) { + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for day transform", + source_type->ToString()); + } return std::make_unique(source_type); } @@ -203,15 +207,7 @@ Result HourTransform::Transform(const ArrowArray& input) { } Result> HourTransform::ResultType() const { - auto src_type = source_type(); - switch (src_type->type_id()) { - case TypeId::kTimestamp: - case TypeId::kTimestampTz: - return iceberg::int32(); - default: - return NotSupported("{} is not a valid input type for hour transform", - src_type->ToString()); - } + return iceberg::int32(); } Result> HourTransform::Make( @@ -219,6 +215,14 @@ Result> HourTransform::Make( if (!source_type) { 
return NotSupported("null is not a valid input type for hour transform"); } + switch (source_type->type_id()) { + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for hour transform", + source_type->ToString()); + } return std::make_unique(source_type); } diff --git a/test/transform_test.cc b/test/transform_test.cc index 0f35b4543..33149d14d 100644 --- a/test/transform_test.cc +++ b/test/transform_test.cc @@ -173,22 +173,15 @@ TEST(TransformResultTypeTest, NegativeCases) { }; const std::vector cases = { + {.str = "identity", .source_type = nullptr}, {.str = "year", .source_type = iceberg::string()}, {.str = "month", .source_type = iceberg::string()}, {.str = "day", .source_type = iceberg::string()}, {.str = "hour", .source_type = iceberg::string()}, + {.str = "void", .source_type = nullptr}, {.str = "bucket[16]", .source_type = iceberg::float32()}, {.str = "truncate[32]", .source_type = iceberg::float64()}}; - const std::vector null_cases = {{.str = "identity", .source_type = nullptr}, - {.str = "year", .source_type = nullptr}, - {.str = "month", .source_type = nullptr}, - {.str = "day", .source_type = nullptr}, - {.str = "hour", .source_type = nullptr}, - {.str = "void", .source_type = nullptr}, - {.str = "bucket[16]", .source_type = nullptr}, - {.str = "truncate[32]", .source_type = nullptr}}; - for (const auto& c : cases) { auto result = TransformFromString(c.str); ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; @@ -196,19 +189,7 @@ TEST(TransformResultTypeTest, NegativeCases) { const auto& transform = result.value(); auto transformPtr = transform->Bind(c.source_type); - auto result_type = transformPtr.value()->ResultType(); - ASSERT_THAT(result_type, IsError(ErrorKind::kNotSupported)); - } - - for (const auto& c : null_cases) { - auto result = TransformFromString(c.str); - ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str; - - const auto& transform = 
result.value(); - auto transformPtr = transform->Bind(c.source_type); - ASSERT_THAT(transformPtr, IsError(ErrorKind::kNotSupported)); - EXPECT_THAT(transformPtr, HasErrorMessage("null is not a valid")); } }