diff --git a/src/iceberg/CMakeLists.txt b/src/iceberg/CMakeLists.txt index 369666b75..275d71fce 100644 --- a/src/iceberg/CMakeLists.txt +++ b/src/iceberg/CMakeLists.txt @@ -29,6 +29,7 @@ set(ICEBERG_SOURCES expression/literal.cc expression/predicate.cc expression/rewrite_not.cc + expression/strict_metrics_evaluator.cc expression/term.cc file_reader.cc file_writer.cc diff --git a/src/iceberg/catalog/rest/meson.build b/src/iceberg/catalog/rest/meson.build index 89a68850e..8378b2a8c 100644 --- a/src/iceberg/catalog/rest/meson.build +++ b/src/iceberg/catalog/rest/meson.build @@ -61,7 +61,6 @@ install_headers( 'error_handlers.h', 'http_client.h', 'iceberg_rest_export.h', - 'json_internal.h', 'resource_paths.h', 'rest_catalog.h', 'rest_util.h', diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc index a9c1a60bf..0c5352d79 100644 --- a/src/iceberg/expression/aggregate.cc +++ b/src/iceberg/expression/aggregate.cc @@ -19,11 +19,15 @@ #include "iceberg/expression/aggregate.h" +#include #include +#include #include +#include #include #include "iceberg/expression/literal.h" +#include "iceberg/manifest/manifest_entry.h" #include "iceberg/row/struct_like.h" #include "iceberg/type.h" #include "iceberg/util/checked_cast.h" @@ -38,6 +42,19 @@ std::shared_ptr GetPrimitiveType(const BoundTerm& term) { return internal::checked_pointer_cast(term.type()); } +Result EvaluateBoundTerm(const BoundTerm& term, + const std::optional>& bound) { + auto ptype = GetPrimitiveType(term); + if (!bound.has_value()) { + SingleValueStructLike data(Literal::Null(ptype)); + return term.Evaluate(data); + } + + ICEBERG_ASSIGN_OR_RAISE(auto literal, Literal::Deserialize(*bound, ptype)); + SingleValueStructLike data(std::move(literal)); + return term.Evaluate(data); +} + class CountAggregator : public BoundAggregate::Aggregator { public: explicit CountAggregator(const CountAggregate& aggregate) : aggregate_(aggregate) {} @@ -48,11 +65,32 @@ class CountAggregator : public 
BoundAggregate::Aggregator { return {}; } - Literal GetResult() const override { return Literal::Long(count_); } + Status Update(const DataFile& file) override { + if (!valid_) { + return {}; + } + if (!aggregate_.HasValue(file)) { + valid_ = false; + return {}; + } + ICEBERG_ASSIGN_OR_RAISE(auto count, aggregate_.CountFor(file)); + count_ += count; + return {}; + } + + Literal GetResult() const override { + if (!valid_) { + return Literal::Null(int64()); + } + return Literal::Long(count_); + } + + bool IsValid() const override { return valid_; } private: const CountAggregate& aggregate_; int64_t count_ = 0; + bool valid_ = true; }; class MaxAggregator : public BoundAggregate::Aggregator { @@ -82,11 +120,47 @@ class MaxAggregator : public BoundAggregate::Aggregator { return {}; } - Literal GetResult() const override { return current_; } + Status Update(const DataFile& file) override { + if (!valid_) { + return {}; + } + if (!aggregate_.HasValue(file)) { + valid_ = false; + return {}; + } + + ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(file)); + if (value.IsNull()) { + return {}; + } + if (current_.IsNull()) { + current_ = std::move(value); + return {}; + } + + if (auto ordering = value <=> current_; + ordering == std::partial_ordering::unordered) { + return InvalidArgument("Cannot compare literal {} with current value {}", + value.ToString(), current_.ToString()); + } else if (ordering == std::partial_ordering::greater) { + current_ = std::move(value); + } + return {}; + } + + Literal GetResult() const override { + if (!valid_) { + return Literal::Null(GetPrimitiveType(*aggregate_.term())); + } + return current_; + } + + bool IsValid() const override { return valid_; } private: const MaxAggregate& aggregate_; Literal current_; + bool valid_ = true; }; class MinAggregator : public BoundAggregate::Aggregator { @@ -115,13 +189,65 @@ class MinAggregator : public BoundAggregate::Aggregator { return {}; } - Literal GetResult() const override { return 
current_; } + Status Update(const DataFile& file) override { + if (!valid_) { + return {}; + } + if (!aggregate_.HasValue(file)) { + valid_ = false; + return {}; + } + + ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(file)); + if (value.IsNull()) { + return {}; + } + if (current_.IsNull()) { + current_ = std::move(value); + return {}; + } + + if (auto ordering = value <=> current_; + ordering == std::partial_ordering::unordered) { + return InvalidArgument("Cannot compare literal {} with current value {}", + value.ToString(), current_.ToString()); + } else if (ordering == std::partial_ordering::less) { + current_ = std::move(value); + } + return {}; + } + + Literal GetResult() const override { + if (!valid_) { + return Literal::Null(GetPrimitiveType(*aggregate_.term())); + } + return current_; + } + + bool IsValid() const override { return valid_; } private: const MinAggregate& aggregate_; Literal current_; + bool valid_ = true; }; +template +std::optional GetMapValue(const std::map& map, int32_t key) { + auto iter = map.find(key); + if (iter == map.end()) { + return std::nullopt; + } + return iter->second; +} + +int32_t GetFieldId(const std::shared_ptr& term) { + ICEBERG_DCHECK(term != nullptr, "Aggregate term should not be null"); + auto ref = term->reference(); + ICEBERG_DCHECK(ref != nullptr, "Aggregate term reference should not be null"); + return ref->field().field_id(); +} + } // namespace template @@ -149,7 +275,11 @@ std::string Aggregate::ToString() const { // -------------------- CountAggregate -------------------- Result CountAggregate::Evaluate(const StructLike& data) const { - return CountFor(data).transform([](int64_t count) { return Literal::Long(count); }); + return CountFor(data).transform(Literal::Long); +} + +Result CountAggregate::Evaluate(const DataFile& file) const { + return CountFor(file).transform(Literal::Long); } std::unique_ptr CountAggregate::NewAggregator() const { @@ -173,6 +303,22 @@ Result 
CountNonNullAggregate::CountFor(const StructLike& data) const { [](const auto& val) { return val.IsNull() ? 0 : 1; }); } +Result CountNonNullAggregate::CountFor(const DataFile& file) const { + auto field_id = GetFieldId(term()); + if (!HasValue(file)) { + return NotFound("Missing metrics for field id {}", field_id); + } + auto value_count = GetMapValue(file.value_counts, field_id).value(); + auto null_count = GetMapValue(file.null_value_counts, field_id).value(); + return value_count - null_count; +} + +bool CountNonNullAggregate::HasValue(const DataFile& file) const { + auto field_id = GetFieldId(term()); + return file.value_counts.contains(field_id) && + file.null_value_counts.contains(field_id); +} + CountNullAggregate::CountNullAggregate(std::shared_ptr term) : CountAggregate(Expression::Operation::kCountNull, std::move(term)) {} @@ -189,6 +335,18 @@ Result CountNullAggregate::CountFor(const StructLike& data) const { [](const auto& val) { return val.IsNull() ? 1 : 0; }); } +Result CountNullAggregate::CountFor(const DataFile& file) const { + auto field_id = GetFieldId(term()); + if (!HasValue(file)) { + return NotFound("Missing metrics for field id {}", field_id); + } + return GetMapValue(file.null_value_counts, field_id).value(); +} + +bool CountNullAggregate::HasValue(const DataFile& file) const { + return file.null_value_counts.contains(GetFieldId(term())); +} + CountStarAggregate::CountStarAggregate() : CountAggregate(Expression::Operation::kCountStar, nullptr) {} @@ -200,6 +358,17 @@ Result CountStarAggregate::CountFor(const StructLike& /*data*/) const { return 1; } +Result CountStarAggregate::CountFor(const DataFile& file) const { + if (!HasValue(file)) { + return NotFound("Record count is missing"); + } + return file.record_count; +} + +bool CountStarAggregate::HasValue(const DataFile& file) const { + return file.record_count >= 0; +} + MaxAggregate::MaxAggregate(std::shared_ptr term) : BoundAggregate(Expression::Operation::kMax, std::move(term)) {} @@ 
-211,10 +380,26 @@ Result MaxAggregate::Evaluate(const StructLike& data) const { return term()->Evaluate(data); } +Result MaxAggregate::Evaluate(const DataFile& file) const { + auto field_id = GetFieldId(term()); + auto upper = GetMapValue(file.upper_bounds, field_id); + return EvaluateBoundTerm(*term(), upper); +} + std::unique_ptr MaxAggregate::NewAggregator() const { return std::unique_ptr(new MaxAggregator(*this)); } +bool MaxAggregate::HasValue(const DataFile& file) const { + auto field_id = GetFieldId(term()); + bool has_bound = file.upper_bounds.contains(field_id); + auto value_count = GetMapValue(file.value_counts, field_id); + auto null_count = GetMapValue(file.null_value_counts, field_id); + bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() && + null_count.value() == value_count.value(); + return has_bound || all_null; +} + MinAggregate::MinAggregate(std::shared_ptr term) : BoundAggregate(Expression::Operation::kMin, std::move(term)) {} @@ -226,10 +411,26 @@ Result MinAggregate::Evaluate(const StructLike& data) const { return term()->Evaluate(data); } +Result MinAggregate::Evaluate(const DataFile& file) const { + auto field_id = GetFieldId(term()); + auto lower = GetMapValue(file.lower_bounds, field_id); + return EvaluateBoundTerm(*term(), lower); +} + std::unique_ptr MinAggregate::NewAggregator() const { return std::unique_ptr(new MinAggregator(*this)); } +bool MinAggregate::HasValue(const DataFile& file) const { + auto field_id = GetFieldId(term()); + bool has_bound = file.lower_bounds.contains(field_id); + auto value_count = GetMapValue(file.value_counts, field_id); + auto null_count = GetMapValue(file.null_value_counts, field_id); + bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() && + null_count.value() == value_count.value(); + return has_bound || all_null; +} + // -------------------- Unbound binding -------------------- template @@ -275,8 +476,10 @@ Result>> 
UnboundAggregateImpl::Make( } template class Aggregate>; +template class Aggregate>; template class Aggregate; template class UnboundAggregateImpl; +template class UnboundAggregateImpl; // -------------------- AggregateEvaluator -------------------- @@ -296,6 +499,13 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { return {}; } + Status Update(const DataFile& file) override { + for (auto& aggregator : aggregators_) { + ICEBERG_RETURN_UNEXPECTED(aggregator->Update(file)); + } + return {}; + } + Result> GetResults() const override { results_.clear(); results_.reserve(aggregates_.size()); @@ -315,6 +525,10 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { return all.front(); } + bool AllAggregatorsValid() const override { + return std::ranges::all_of(aggregators_, &BoundAggregate::Aggregator::IsValid); + } + private: std::vector> aggregates_; std::vector> aggregators_; diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index cde9e4583..3b5909725 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -109,9 +109,9 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound virtual Status Update(const StructLike& data) = 0; - virtual Status Update(const DataFile& file) { - return NotImplemented("Update(DataFile) not implemented"); - } + virtual Status Update(const DataFile& file) = 0; + + virtual bool IsValid() const { return true; } /// \brief Get the result of the aggregation. /// \return The result of the aggregation. @@ -127,6 +127,10 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound } Result Evaluate(const StructLike& data) const override = 0; + virtual Result Evaluate(const DataFile& file) const = 0; + + /// \brief Whether metrics in the data file are sufficient to evaluate. 
+ virtual bool HasValue(const DataFile& file) const = 0; bool is_bound_aggregate() const override { return true; } @@ -143,11 +147,16 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound class ICEBERG_EXPORT CountAggregate : public BoundAggregate { public: Result Evaluate(const StructLike& data) const final; + Result Evaluate(const DataFile& file) const override; std::unique_ptr NewAggregator() const override; /// \brief Count for a single row. Subclasses implement this. virtual Result CountFor(const StructLike& data) const = 0; + /// \brief Count using metrics from a data file. + virtual Result CountFor(const DataFile& file) const = 0; + + bool HasValue(const DataFile& file) const override = 0; protected: CountAggregate(Expression::Operation op, std::shared_ptr term) @@ -161,6 +170,9 @@ class ICEBERG_EXPORT CountNonNullAggregate : public CountAggregate { std::shared_ptr term); Result CountFor(const StructLike& data) const override; + Result CountFor(const DataFile& file) const override; + + bool HasValue(const DataFile& file) const override; private: explicit CountNonNullAggregate(std::shared_ptr term); @@ -173,6 +185,9 @@ class ICEBERG_EXPORT CountNullAggregate : public CountAggregate { std::shared_ptr term); Result CountFor(const StructLike& data) const override; + Result CountFor(const DataFile& file) const override; + + bool HasValue(const DataFile& file) const override; private: explicit CountNullAggregate(std::shared_ptr term); @@ -184,6 +199,9 @@ class ICEBERG_EXPORT CountStarAggregate : public CountAggregate { static Result> Make(); Result CountFor(const StructLike& data) const override; + Result CountFor(const DataFile& file) const override; + + bool HasValue(const DataFile& file) const override; private: CountStarAggregate(); @@ -195,9 +213,12 @@ class ICEBERG_EXPORT MaxAggregate : public BoundAggregate { static std::shared_ptr Make(std::shared_ptr term); Result Evaluate(const StructLike& data) const override; + Result Evaluate(const 
DataFile& file) const final; std::unique_ptr NewAggregator() const override; + bool HasValue(const DataFile& file) const override; + private: explicit MaxAggregate(std::shared_ptr term); }; @@ -208,9 +229,12 @@ class ICEBERG_EXPORT MinAggregate : public BoundAggregate { static std::shared_ptr Make(std::shared_ptr term); Result Evaluate(const StructLike& data) const override; + Result Evaluate(const DataFile& file) const final; std::unique_ptr NewAggregator() const override; + bool HasValue(const DataFile& file) const override; + private: explicit MinAggregate(std::shared_ptr term); }; @@ -234,11 +258,17 @@ class ICEBERG_EXPORT AggregateEvaluator { /// \brief Update aggregates with a row. virtual Status Update(const StructLike& data) = 0; + /// \brief Update aggregates using data file metrics. + virtual Status Update(const DataFile& file) = 0; + /// \brief Final aggregated value. virtual Result> GetResults() const = 0; /// \brief Convenience accessor when only one aggregate is evaluated. virtual Result GetResult() const = 0; + + /// \brief Whether all aggregators are still valid (metrics present). 
+ virtual bool AllAggregatorsValid() const = 0; }; } // namespace iceberg diff --git a/src/iceberg/expression/expressions.cc b/src/iceberg/expression/expressions.cc index 7eef60232..4b0e538ae 100644 --- a/src/iceberg/expression/expressions.cc +++ b/src/iceberg/expression/expressions.cc @@ -138,6 +138,13 @@ std::shared_ptr> Expressions::Max( return agg; } +std::shared_ptr> Expressions::Max( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kMax, std::move(expr))); + return agg; +} + std::shared_ptr> Expressions::Min(std::string name) { return Min(Ref(std::move(name))); } @@ -149,6 +156,13 @@ std::shared_ptr> Expressions::Min( return agg; } +std::shared_ptr> Expressions::Min( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kMin, std::move(expr))); + return agg; +} + // Template implementations for unary predicates std::shared_ptr> Expressions::IsNull( diff --git a/src/iceberg/expression/expressions.h b/src/iceberg/expression/expressions.h index 92c523ca7..4ef2e7800 100644 --- a/src/iceberg/expression/expressions.h +++ b/src/iceberg/expression/expressions.h @@ -135,6 +135,9 @@ class ICEBERG_EXPORT Expressions { /// \brief Create a MAX aggregate for an unbound term. static std::shared_ptr> Max( std::shared_ptr> expr); + /// \brief Create a MAX aggregate for an unbound transform term. + static std::shared_ptr> Max( + std::shared_ptr> expr); /// \brief Create a MIN aggregate for a field name. static std::shared_ptr> Min(std::string name); @@ -142,6 +145,9 @@ class ICEBERG_EXPORT Expressions { /// \brief Create a MIN aggregate for an unbound term. static std::shared_ptr> Min( std::shared_ptr> expr); + /// \brief Create a MIN aggregate for an unbound transform term. 
+ static std::shared_ptr> Min( + std::shared_ptr> expr); // Unary predicates diff --git a/src/iceberg/expression/meson.build b/src/iceberg/expression/meson.build index 830059087..8e312791b 100644 --- a/src/iceberg/expression/meson.build +++ b/src/iceberg/expression/meson.build @@ -17,13 +17,17 @@ install_headers( [ + 'aggregate.h', 'binder.h', + 'evaluator.h', 'expression.h', 'expression_visitor.h', 'expressions.h', + 'inclusive_metrics_evaluator.h', 'literal.h', 'predicate.h', 'rewrite_not.h', + 'strict_metrics_evaluator.h', 'term.h', ], subdir: 'iceberg/expression', diff --git a/src/iceberg/expression/strict_metrics_evaluator.cc b/src/iceberg/expression/strict_metrics_evaluator.cc new file mode 100644 index 000000000..e2fe34f14 --- /dev/null +++ b/src/iceberg/expression/strict_metrics_evaluator.cc @@ -0,0 +1,506 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "iceberg/expression/strict_metrics_evaluator.h" + +#include "iceberg/expression/binder.h" +#include "iceberg/expression/expression_visitor.h" +#include "iceberg/expression/rewrite_not.h" +#include "iceberg/expression/term.h" +#include "iceberg/manifest/manifest_entry.h" +#include "iceberg/schema.h" +#include "iceberg/type.h" +#include "iceberg/util/macros.h" + +namespace iceberg { + +namespace { +constexpr bool kRowsMustMatch = true; +constexpr bool kRowsMightNotMatch = false; +} // namespace + +// If the term in any expression is not a direct reference, assume that rows may not +// match. This happens when transforms or other expressions are passed to this evaluator. +// For example, bucket16(x) = 0 can't be determined because this visitor operates on data +// metrics and not partition values. It may be possible to un-transform expressions for +// order preserving transforms in the future, but this is not currently supported. +#define RETURN_IF_NOT_REFERENCE(expr) \ + if (auto ref = dynamic_cast(expr.get()); ref == nullptr) { \ + return kRowsMightNotMatch; \ + } + +class StrictMetricsVisitor : public BoundVisitor { + public: + explicit StrictMetricsVisitor(const DataFile& data_file, const Schema& schema) + : data_file_(data_file), schema_(schema) {} + + Result AlwaysTrue() override { return kRowsMustMatch; } + + Result AlwaysFalse() override { return kRowsMightNotMatch; } + + Result Not(bool child_result) override { return !child_result; } + + Result And(bool left_result, bool right_result) override { + return left_result && right_result; + } + + Result Or(bool left_result, bool right_result) override { + return left_result || right_result; + } + + Result IsNull(const std::shared_ptr& expr) override { + RETURN_IF_NOT_REFERENCE(expr); + + // no need to check whether the field is required because binding evaluates that case + // if the column has any non-null values, the expression does not match + int32_t id = expr->reference()->field().field_id(); 
+ + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + if (ContainsNullsOnly(id)) { + return kRowsMustMatch; + } + return kRowsMightNotMatch; + } + + Result NotNull(const std::shared_ptr& expr) override { + RETURN_IF_NOT_REFERENCE(expr); + + // no need to check whether the field is required because binding evaluates that case + // if the column has any null values, the expression does not match + int32_t id = expr->reference()->field().field_id(); + + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + auto it = data_file_.null_value_counts.find(id); + if (it != data_file_.null_value_counts.cend() && it->second == 0) { + return kRowsMustMatch; + } + + return kRowsMightNotMatch; + } + + Result IsNaN(const std::shared_ptr& expr) override { + RETURN_IF_NOT_REFERENCE(expr); + + int32_t id = expr->reference()->field().field_id(); + + if (ContainsNaNsOnly(id)) { + return kRowsMustMatch; + } + + return kRowsMightNotMatch; + } + + Result NotNaN(const std::shared_ptr& expr) override { + RETURN_IF_NOT_REFERENCE(expr); + + int32_t id = expr->reference()->field().field_id(); + + auto it = data_file_.nan_value_counts.find(id); + if (it != data_file_.nan_value_counts.cend() && it->second == 0) { + return kRowsMustMatch; + } + + if (ContainsNullsOnly(id)) { + return kRowsMustMatch; + } + + return kRowsMightNotMatch; + } + + Result Lt(const std::shared_ptr& expr, const Literal& lit) override { + RETURN_IF_NOT_REFERENCE(expr); + + // Rows must match when: <----------Min----Max---X-------> + int32_t id = expr->reference()->field().field_id(); + + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + if (CanContainNulls(id) || CanContainNaNs(id)) { + return kRowsMightNotMatch; + } + + auto it = data_file_.upper_bounds.find(id); + if (it != data_file_.upper_bounds.cend()) { + 
ICEBERG_ASSIGN_OR_RAISE(auto upper, ParseBound(expr, it->second)); + if (upper < lit) { + return kRowsMustMatch; + } + } + + return kRowsMightNotMatch; + } + + Result LtEq(const std::shared_ptr& expr, const Literal& lit) override { + RETURN_IF_NOT_REFERENCE(expr); + + // Rows must match when: <----------Min----Max---X-------> + int32_t id = expr->reference()->field().field_id(); + + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + if (CanContainNulls(id) || CanContainNaNs(id)) { + return kRowsMightNotMatch; + } + + auto it = data_file_.upper_bounds.find(id); + if (it != data_file_.upper_bounds.cend()) { + ICEBERG_ASSIGN_OR_RAISE(auto upper, ParseBound(expr, it->second)); + if (upper <= lit) { + return kRowsMustMatch; + } + } + + return kRowsMightNotMatch; + } + + Result Gt(const std::shared_ptr& expr, const Literal& lit) override { + RETURN_IF_NOT_REFERENCE(expr); + + // Rows must match when: <-------X---Min----Max----------> + int32_t id = expr->reference()->field().field_id(); + + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + if (CanContainNulls(id) || CanContainNaNs(id)) { + return kRowsMightNotMatch; + } + + auto it = data_file_.lower_bounds.find(id); + if (it != data_file_.lower_bounds.cend()) { + ICEBERG_ASSIGN_OR_RAISE(auto lower, ParseBound(expr, it->second)); + if (lower.IsNaN()) { + // NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for + // more. 
+ return kRowsMightNotMatch; + } + + if (lower > lit) { + return kRowsMustMatch; + } + } + + return kRowsMightNotMatch; + } + + Result GtEq(const std::shared_ptr& expr, const Literal& lit) override { + RETURN_IF_NOT_REFERENCE(expr); + + // Rows must match when: <-------X---Min----Max----------> + int32_t id = expr->reference()->field().field_id(); + + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + if (CanContainNulls(id) || CanContainNaNs(id)) { + return kRowsMightNotMatch; + } + + auto it = data_file_.lower_bounds.find(id); + if (it != data_file_.lower_bounds.cend()) { + ICEBERG_ASSIGN_OR_RAISE(auto lower, ParseBound(expr, it->second)); + if (lower.IsNaN()) { + // NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for + // more. + return kRowsMightNotMatch; + } + + if (lower >= lit) { + return kRowsMustMatch; + } + } + + return kRowsMightNotMatch; + } + + Result Eq(const std::shared_ptr& expr, const Literal& lit) override { + RETURN_IF_NOT_REFERENCE(expr); + + // Rows must match when Min == X == Max + int32_t id = expr->reference()->field().field_id(); + + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + if (CanContainNulls(id) || CanContainNaNs(id)) { + return kRowsMightNotMatch; + } + auto lower_it = data_file_.lower_bounds.find(id); + auto upper_it = data_file_.upper_bounds.find(id); + if (lower_it != data_file_.lower_bounds.cend() && + upper_it != data_file_.upper_bounds.cend()) { + ICEBERG_ASSIGN_OR_RAISE(auto lower, ParseBound(expr, lower_it->second)); + if (lower != lit) { + return kRowsMightNotMatch; + } + ICEBERG_ASSIGN_OR_RAISE(auto upper, ParseBound(expr, upper_it->second)); + if (upper != lit) { + return kRowsMightNotMatch; + } + + return kRowsMustMatch; + } + + return kRowsMightNotMatch; + } + + Result NotEq(const std::shared_ptr& expr, const Literal& lit) override { + 
RETURN_IF_NOT_REFERENCE(expr); + + // Rows must match when X < Min or Max < X because it is not in the range + int32_t id = expr->reference()->field().field_id(); + + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + if (ContainsNullsOnly(id) || ContainsNaNsOnly(id)) { + return kRowsMustMatch; + } + + auto lower_it = data_file_.lower_bounds.find(id); + if (lower_it != data_file_.lower_bounds.cend()) { + ICEBERG_ASSIGN_OR_RAISE(auto lower, ParseBound(expr, lower_it->second)); + if (lower.IsNaN()) { + // NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for + // more. + return kRowsMightNotMatch; + } + if (lower > lit) { + return kRowsMustMatch; + } + } + + auto upper_it = data_file_.upper_bounds.find(id); + if (upper_it != data_file_.upper_bounds.cend()) { + ICEBERG_ASSIGN_OR_RAISE(auto upper, ParseBound(expr, upper_it->second)); + if (upper < lit) { + return kRowsMustMatch; + } + } + + return kRowsMightNotMatch; + } + + Result In(const std::shared_ptr& expr, + const BoundSetPredicate::LiteralSet& literal_set) override { + RETURN_IF_NOT_REFERENCE(expr); + + int32_t id = expr->reference()->field().field_id(); + + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + if (CanContainNulls(id) || CanContainNaNs(id)) { + return kRowsMightNotMatch; + } + auto lower_it = data_file_.lower_bounds.find(id); + auto upper_it = data_file_.upper_bounds.find(id); + if (lower_it != data_file_.lower_bounds.cend() && + upper_it != data_file_.upper_bounds.cend()) { + // similar to the implementation in eq, first check if the lower bound is in the + // set + ICEBERG_ASSIGN_OR_RAISE(auto lower, ParseBound(expr, lower_it->second)); + if (!literal_set.contains(lower)) { + return kRowsMightNotMatch; + } + // check if the upper bound is in the set + ICEBERG_ASSIGN_OR_RAISE(auto upper, ParseBound(expr, upper_it->second)); + if 
(!literal_set.contains(upper)) { + return kRowsMightNotMatch; + } + // finally check if the lower bound and the upper bound are equal + if (lower != upper) { + return kRowsMightNotMatch; + } + + // All values must be in the set if the lower bound and the upper bound are in the + // set and are equal. + return kRowsMustMatch; + } + + return kRowsMightNotMatch; + } + + Result NotIn(const std::shared_ptr& expr, + const BoundSetPredicate::LiteralSet& literal_set) override { + RETURN_IF_NOT_REFERENCE(expr); + + int32_t id = expr->reference()->field().field_id(); + + ICEBERG_ASSIGN_OR_RAISE(auto is_nested, IsNestedColumn(id)); + if (is_nested) { + return kRowsMightNotMatch; + } + + if (ContainsNullsOnly(id) || ContainsNaNsOnly(id)) { + return kRowsMustMatch; + } + std::optional lower_bound; + auto lower_it = data_file_.lower_bounds.find(id); + if (lower_it != data_file_.lower_bounds.cend()) { + ICEBERG_ASSIGN_OR_RAISE(auto lower, ParseBound(expr, lower_it->second)); + if (lower.IsNaN()) { + // NaN indicates unreliable bounds. See the StrictMetricsEvaluator docs for + // more. + return kRowsMightNotMatch; + } + lower_bound = std::move(lower); + } + auto literals_view = literal_set | std::views::filter([&](const Literal& lit) { + return !lower_bound.has_value() || lower_bound.value() <= lit; + }); + // Without a lower bound every literal is kept. With one, if all literals are + // below it, rows must match (notIn). + if (lower_bound.has_value() && literals_view.empty()) { + return kRowsMustMatch; + } + + auto upper_it = data_file_.upper_bounds.find(id); + if (upper_it != data_file_.upper_bounds.cend()) { + ICEBERG_ASSIGN_OR_RAISE(auto upper, ParseBound(expr, upper_it->second)); + auto filtered_view = literals_view | std::views::filter([&](const Literal& lit) { + return upper >= lit; + }); + if (filtered_view.empty()) { + // if all remaining values are greater than upper bound, + // rows must match + // (notIn).
+ return kRowsMustMatch; + } + } + return kRowsMightNotMatch; + } + + Result StartsWith(const std::shared_ptr& expr, + const Literal& lit) override { + return kRowsMightNotMatch; + } + + Result NotStartsWith(const std::shared_ptr& expr, + const Literal& lit) override { + // TODO(xiao.dong) Handle cases where no value can start with the prefix (so all + // rows must match), such as notStartsWith("x") when + // the bounds are ["a", "b"]. + return kRowsMightNotMatch; + } + + private: + Result ParseBound(const std::shared_ptr& expr, + const std::vector& stats) { + auto type = expr->reference()->type(); + if (!type->is_primitive()) { + return NotSupported("Bound of non-primitive type is not supported."); + } + auto primitive_type = internal::checked_pointer_cast(type); + return Literal::Deserialize(stats, primitive_type); + } + + bool CanContainNulls(int32_t id) { + if (data_file_.null_value_counts.empty()) { + return true;  // no null counts recorded at all: assume nulls are possible + } + auto it = data_file_.null_value_counts.find(id); + return it != data_file_.null_value_counts.cend() && it->second > 0; + } + + bool CanContainNaNs(int32_t id) { + // nan counts might be missing for early version writers when nan counters are not + // populated.
+ auto it = data_file_.nan_value_counts.find(id); + return it != data_file_.nan_value_counts.cend() && it->second > 0; + } + + bool ContainsNullsOnly(int32_t id) { + auto val_it = data_file_.value_counts.find(id); + auto null_it = data_file_.null_value_counts.find(id); + return val_it != data_file_.value_counts.cend() && + null_it != data_file_.null_value_counts.cend() && + val_it->second == null_it->second; + } + + bool ContainsNaNsOnly(int32_t id) { + auto val_it = data_file_.value_counts.find(id); + auto nan_it = data_file_.nan_value_counts.find(id); + return val_it != data_file_.value_counts.cend() && + nan_it != data_file_.nan_value_counts.cend() && + val_it->second == nan_it->second; + } + + Result IsNestedColumn(int32_t id) { + // XXX: null_count might be missing from nested columns but required by + // StrictMetricsEvaluator. + // See https://github.com/apache/iceberg/pull/11261. + ICEBERG_ASSIGN_OR_RAISE(auto field, schema_.GetFieldById(id)); + return !field.has_value() || field->get().type()->is_nested(); + } + + private: + const DataFile& data_file_; + const Schema& schema_; +}; + +StrictMetricsEvaluator::StrictMetricsEvaluator(std::shared_ptr expr, + std::shared_ptr schema) + : expr_(std::move(expr)), schema_(std::move(schema)) {} + +StrictMetricsEvaluator::~StrictMetricsEvaluator() = default; + +Result> StrictMetricsEvaluator::Make( + std::shared_ptr expr, std::shared_ptr schema, + bool case_sensitive) { + ICEBERG_ASSIGN_OR_RAISE(auto rewrite_expr, RewriteNot::Visit(std::move(expr))); + ICEBERG_ASSIGN_OR_RAISE(auto bound_expr, + Binder::Bind(*schema, rewrite_expr, case_sensitive)); + return std::unique_ptr( + new StrictMetricsEvaluator(std::move(bound_expr), std::move(schema))); +} + +Result StrictMetricsEvaluator::Evaluate(const DataFile& data_file) const { + if (data_file.record_count <= 0) { + return kRowsMustMatch; + } + StrictMetricsVisitor visitor(data_file, *schema_); + return Visit(expr_, visitor); +} + +} // namespace iceberg diff --git 
a/src/iceberg/expression/strict_metrics_evaluator.h b/src/iceberg/expression/strict_metrics_evaluator.h new file mode 100644 index 000000000..60dc74a9c --- /dev/null +++ b/src/iceberg/expression/strict_metrics_evaluator.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +/// \file iceberg/expression/strict_metrics_evaluator.h +/// +/// Evaluates an Expression on a DataFile to test whether all rows in the file match. +/// +/// This evaluation is strict: it returns true if all rows in a file must match the +/// expression. For example, if a file's ts column has min X and max Y, this evaluator +/// will return true for ts < Y+1 but not for ts < Y-1. +/// +/// Files are passed to #eval(ContentFile), which returns true if all rows in the file +/// must contain matching rows and false if the file may contain rows that do not match. +/// +/// Due to the comparison implementation of ORC stats, for float/double columns in ORC +/// files, if the first value in a file is NaN, metrics of this file will report NaN for +/// both upper and lower bound despite that the column could contain non-NaN data. 
Thus in +/// some scenarios explicitly checks for NaN is necessary in order to not include files +/// that may contain rows that don't match. +/// + +#include + +#include "iceberg/expression/expression.h" +#include "iceberg/iceberg_export.h" +#include "iceberg/result.h" +#include "iceberg/type_fwd.h" + +namespace iceberg { + +/// \brief Evaluates an Expression against DataFile. +/// \note: The evaluator is thread-safe. +class ICEBERG_EXPORT StrictMetricsEvaluator { + public: + /// \brief Make a strict metrics evaluator + /// + /// \param expr The expression to evaluate + /// \param schema The schema of the table + /// \param case_sensitive Whether field name matching is case-sensitive + static Result> Make( + std::shared_ptr expr, std::shared_ptr schema, + bool case_sensitive = true); + + ~StrictMetricsEvaluator(); + + /// \brief Evaluate the expression against a DataFile. + /// + /// \param data_file The data file to evaluate + /// \return true if the file matches the expression, false otherwise, or error + Result Evaluate(const DataFile& data_file) const; + + private: + explicit StrictMetricsEvaluator(std::shared_ptr expr, + std::shared_ptr schema); + + private: + std::shared_ptr expr_; + std::shared_ptr schema_; +}; + +} // namespace iceberg diff --git a/src/iceberg/meson.build b/src/iceberg/meson.build index 5a9933385..c139c66b5 100644 --- a/src/iceberg/meson.build +++ b/src/iceberg/meson.build @@ -51,6 +51,7 @@ iceberg_sources = files( 'expression/literal.cc', 'expression/predicate.cc', 'expression/rewrite_not.cc', + 'expression/strict_metrics_evaluator.cc', 'expression/term.cc', 'file_reader.cc', 'file_writer.cc', diff --git a/src/iceberg/row/struct_like.cc b/src/iceberg/row/struct_like.cc index b0fb67fb4..62db1dcdb 100644 --- a/src/iceberg/row/struct_like.cc +++ b/src/iceberg/row/struct_like.cc @@ -19,7 +19,9 @@ #include "iceberg/row/struct_like.h" +#include #include +#include #include "iceberg/result.h" #include "iceberg/util/checked_cast.h" @@ -28,6 +30,53 
@@ namespace iceberg { +Result LiteralToScalar(const Literal& literal) { + if (literal.IsNull()) { + return Scalar{std::monostate{}}; + } + + switch (literal.type()->type_id()) { + case TypeId::kBoolean: + return Scalar{std::get(literal.value())}; + case TypeId::kInt: + case TypeId::kDate: + return Scalar{std::get(literal.value())}; + case TypeId::kLong: + case TypeId::kTime: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return Scalar{std::get(literal.value())}; + case TypeId::kFloat: + return Scalar{std::get(literal.value())}; + case TypeId::kDouble: + return Scalar{std::get(literal.value())}; + case TypeId::kString: { + const auto& str = std::get(literal.value()); + return Scalar{std::string_view(str)}; + } + case TypeId::kBinary: + case TypeId::kFixed: { + const auto& bytes = std::get>(literal.value()); + return Scalar{ + std::string_view(reinterpret_cast(bytes.data()), bytes.size())}; + } + case TypeId::kDecimal: + return Scalar{std::get(literal.value())}; + default: + return NotSupported("Cannot convert literal of type {} to Scalar", + literal.type()->ToString()); + } +} + +SingleValueStructLike::SingleValueStructLike(Literal literal) + : literal_(std::move(literal)) {} + +Result SingleValueStructLike::GetField(size_t /*pos*/) const { + return LiteralToScalar(literal_); +} + +size_t SingleValueStructLike::num_fields() const { return 1; } + StructLikeAccessor::StructLikeAccessor(std::shared_ptr type, std::span position_path) : type_(std::move(type)) { diff --git a/src/iceberg/row/struct_like.h b/src/iceberg/row/struct_like.h index 4999da69e..dc5fe2579 100644 --- a/src/iceberg/row/struct_like.h +++ b/src/iceberg/row/struct_like.h @@ -55,6 +55,9 @@ using Scalar = std::variant, // for list std::shared_ptr>; // for map +/// \brief Convert a Literal to a Scalar +Result LiteralToScalar(const Literal& literal); + /// \brief An immutable struct-like wrapper. 
class ICEBERG_EXPORT StructLike { public: @@ -68,6 +71,19 @@ class ICEBERG_EXPORT StructLike { virtual size_t num_fields() const = 0; }; +/// \brief A single-field StructLike that wraps a Literal +class ICEBERG_EXPORT SingleValueStructLike : public StructLike { + public: + explicit SingleValueStructLike(Literal literal); + + Result GetField(size_t pos) const override; + + size_t num_fields() const override; + + private: + Literal literal_; +}; + /// \brief An immutable array-like wrapper. class ICEBERG_EXPORT ArrayLike { public: diff --git a/src/iceberg/snapshot.h b/src/iceberg/snapshot.h index d41795d53..5afe2d22e 100644 --- a/src/iceberg/snapshot.h +++ b/src/iceberg/snapshot.h @@ -62,6 +62,8 @@ ICEBERG_EXPORT constexpr Result SnapshotRefTypeFromString( /// \brief A reference to a snapshot, either a branch or a tag. struct ICEBERG_EXPORT SnapshotRef { + static constexpr std::string_view kMainBranch = "main"; + struct ICEBERG_EXPORT Branch { /// A positive number for the minimum number of snapshots to keep in a branch while /// expiring snapshots. 
Defaults to table property diff --git a/src/iceberg/table_metadata.cc b/src/iceberg/table_metadata.cc index 780ab61da..a90c5b0ca 100644 --- a/src/iceberg/table_metadata.cc +++ b/src/iceberg/table_metadata.cc @@ -21,9 +21,12 @@ #include #include +#include #include +#include #include #include +#include #include @@ -39,12 +42,11 @@ #include "iceberg/util/gzip_internal.h" #include "iceberg/util/macros.h" #include "iceberg/util/uuid.h" - namespace iceberg { - namespace { const TimePointMs kInvalidLastUpdatedMs = TimePointMs::min(); -} +constexpr int32_t kLastAdded = -1; +} // namespace std::string ToString(const SnapshotLogEntry& entry) { return std::format("SnapshotLogEntry[timestampMillis={},snapshotId={}]", @@ -274,11 +276,19 @@ struct TableMetadataBuilder::Impl { // Change tracking std::vector> changes; + std::optional last_added_schema_id; + std::optional last_added_order_id; + std::optional last_added_spec_id; // Metadata location tracking std::optional metadata_location; std::optional previous_metadata_location; + // indexes for convenience + std::unordered_map> schemas_by_id; + std::unordered_map> specs_by_id; + std::unordered_map> sort_orders_by_id; + // Constructor for new table explicit Impl(int8_t format_version) : base(nullptr), metadata{} { metadata.format_version = format_version; @@ -294,7 +304,22 @@ struct TableMetadataBuilder::Impl { // Constructor from existing metadata explicit Impl(const TableMetadata* base_metadata) - : base(base_metadata), metadata(*base_metadata) {} + : base(base_metadata), metadata(*base_metadata) { + // Initialize index maps from base metadata + for (const auto& schema : metadata.schemas) { + if (schema->schema_id().has_value()) { + schemas_by_id.emplace(schema->schema_id().value(), schema); + } + } + + for (const auto& spec : metadata.partition_specs) { + specs_by_id.emplace(spec->spec_id(), spec); + } + + for (const auto& order : metadata.sort_orders) { + sort_orders_by_id.emplace(order->order_id(), order); + } + } }; 
TableMetadataBuilder::TableMetadataBuilder(int8_t format_version) @@ -434,16 +459,95 @@ TableMetadataBuilder& TableMetadataBuilder::RemoveSchemas( TableMetadataBuilder& TableMetadataBuilder::SetDefaultSortOrder( std::shared_ptr order) { - throw IcebergError(std::format("{} not implemented", __FUNCTION__)); + BUILDER_ASSIGN_OR_RETURN(auto order_id, AddSortOrderInternal(*order)); + return SetDefaultSortOrder(order_id); } TableMetadataBuilder& TableMetadataBuilder::SetDefaultSortOrder(int32_t order_id) { - throw IcebergError(std::format("{} not implemented", __FUNCTION__)); + if (order_id == -1) { + if (!impl_->last_added_order_id.has_value()) { + return AddError(ErrorKind::kInvalidArgument, + "Cannot set last added sort order: no sort order has been added"); + } + return SetDefaultSortOrder(impl_->last_added_order_id.value()); + } + + if (order_id == impl_->metadata.default_sort_order_id) { + return *this; + } + + impl_->metadata.default_sort_order_id = order_id; + + if (impl_->last_added_order_id == std::make_optional(order_id)) { + impl_->changes.push_back(std::make_unique(kLastAdded)); + } else { + impl_->changes.push_back(std::make_unique(order_id)); + } + return *this; +} + +Result TableMetadataBuilder::AddSortOrderInternal(const SortOrder& order) { + int32_t new_order_id = ReuseOrCreateNewSortOrderId(order); + + if (impl_->sort_orders_by_id.find(new_order_id) != impl_->sort_orders_by_id.end()) { + // update last_added_order_id if the order was added in this set of changes (since it + // is now the last) + bool is_new_order = + impl_->last_added_order_id.has_value() && + std::ranges::find_if(impl_->changes, [new_order_id](const auto& change) { + auto* add_sort_order = dynamic_cast(change.get()); + return add_sort_order && + add_sort_order->sort_order()->order_id() == new_order_id; + }) != impl_->changes.cend(); + impl_->last_added_order_id = + is_new_order ? 
std::make_optional(new_order_id) : std::nullopt; + return new_order_id; + } + + // Get current schema and validate the sort order against it + ICEBERG_ASSIGN_OR_RAISE(auto schema, impl_->metadata.Schema()); + ICEBERG_RETURN_UNEXPECTED(order.Validate(*schema)); + + std::shared_ptr new_order; + if (order.is_unsorted()) { + new_order = SortOrder::Unsorted(); + } else { + // Unlike freshSortOrder from Java impl, we don't use field name from old bound + // schema to rebuild the sort order. + ICEBERG_ASSIGN_OR_RAISE( + new_order, + SortOrder::Make(new_order_id, std::vector(order.fields().begin(), + order.fields().end()))); + } + + impl_->metadata.sort_orders.push_back(new_order); + impl_->sort_orders_by_id.emplace(new_order_id, new_order); + + impl_->changes.push_back(std::make_unique(new_order)); + impl_->last_added_order_id = new_order_id; + return new_order_id; } TableMetadataBuilder& TableMetadataBuilder::AddSortOrder( std::shared_ptr order) { - throw IcebergError(std::format("{} not implemented", __FUNCTION__)); + BUILDER_ASSIGN_OR_RETURN(auto order_id, AddSortOrderInternal(*order)); + return *this; +} + +int32_t TableMetadataBuilder::ReuseOrCreateNewSortOrderId(const SortOrder& new_order) { + if (new_order.is_unsorted()) { + return SortOrder::kUnsortedOrderId; + } + // determine the next order id + int32_t new_order_id = SortOrder::kInitialSortOrderId; + for (const auto& order : impl_->metadata.sort_orders) { + if (order->SameOrder(new_order)) { + return order->order_id(); + } else if (new_order_id <= order->order_id()) { + new_order_id = order->order_id() + 1; + } + } + return new_order_id; } TableMetadataBuilder& TableMetadataBuilder::AddSnapshot( diff --git a/src/iceberg/table_metadata.h b/src/iceberg/table_metadata.h index 503b9b143..2d53fcb08 100644 --- a/src/iceberg/table_metadata.h +++ b/src/iceberg/table_metadata.h @@ -436,6 +436,17 @@ class ICEBERG_EXPORT TableMetadataBuilder : public ErrorCollector { /// \brief Private constructor for building from 
existing metadata explicit TableMetadataBuilder(const TableMetadata* base); + /// \brief Internal method to add a sort order and return its ID + /// \param order The sort order to add + /// \return The ID of the added or reused sort order + Result AddSortOrderInternal(const SortOrder& order); + + /// \brief Internal method to check for existing sort order and reuse its ID or create a + /// new one + /// \param new_order The sort order to check + /// \return The ID to use for this sort order (reused if exists, new otherwise) + int32_t ReuseOrCreateNewSortOrderId(const SortOrder& new_order); + /// Internal state members struct Impl; std::unique_ptr impl_; diff --git a/src/iceberg/table_requirements.cc b/src/iceberg/table_requirements.cc index 3e0aa024f..6de6c59e6 100644 --- a/src/iceberg/table_requirements.cc +++ b/src/iceberg/table_requirements.cc @@ -21,10 +21,10 @@ #include +#include "iceberg/snapshot.h" #include "iceberg/table_metadata.h" #include "iceberg/table_requirement.h" #include "iceberg/table_update.h" -#include "iceberg/util/macros.h" namespace iceberg { @@ -36,12 +36,78 @@ Result>> TableUpdateContext::Build return std::move(requirements_); } +void TableUpdateContext::RequireLastAssignedFieldIdUnchanged() { + if (!added_last_assigned_field_id_) { + if (base_ != nullptr) { + AddRequirement( + std::make_unique(base_->last_column_id)); + } + added_last_assigned_field_id_ = true; + } +} + +void TableUpdateContext::RequireCurrentSchemaIdUnchanged() { + if (!added_current_schema_id_) { + if (base_ != nullptr && !is_replace_) { + AddRequirement(std::make_unique( + base_->current_schema_id.value())); + } + added_current_schema_id_ = true; + } +} + +void TableUpdateContext::RequireLastAssignedPartitionIdUnchanged() { + if (!added_last_assigned_partition_id_) { + if (base_ != nullptr) { + AddRequirement(std::make_unique( + base_->last_partition_id)); + } + added_last_assigned_partition_id_ = true; + } +} + +void TableUpdateContext::RequireDefaultSpecIdUnchanged() 
{ + if (!added_default_spec_id_) { + if (base_ != nullptr && !is_replace_) { + AddRequirement( + std::make_unique(base_->default_spec_id)); + } + added_default_spec_id_ = true; + } +} + +void TableUpdateContext::RequireDefaultSortOrderIdUnchanged() { + if (!added_default_sort_order_id_) { + if (base_ != nullptr && !is_replace_) { + AddRequirement(std::make_unique( + base_->default_sort_order_id)); + } + added_default_sort_order_id_ = true; + } +} + +void TableUpdateContext::RequireNoBranchesChanged() { + if (base_ != nullptr && !is_replace_) { + for (const auto& [name, ref] : base_->refs) { + if (ref->type() == SnapshotRefType::kBranch && name != SnapshotRef::kMainBranch) { + AddRequirement( + std::make_unique(name, ref->snapshot_id)); + } + } + } +} + +bool TableUpdateContext::AddChangedRef(const std::string& ref_name) { + auto [_, inserted] = changed_refs_.insert(ref_name); + return inserted; +} + Result>> TableRequirements::ForCreateTable( const std::vector>& table_updates) { TableUpdateContext context(nullptr, false); context.AddRequirement(std::make_unique()); for (const auto& update : table_updates) { - ICEBERG_RETURN_UNEXPECTED(update->GenerateRequirements(context)); + update->GenerateRequirements(context); } return context.Build(); } @@ -52,7 +118,7 @@ Result>> TableRequirements::ForRep TableUpdateContext context(&base, true); context.AddRequirement(std::make_unique(base.table_uuid)); for (const auto& update : table_updates) { - ICEBERG_RETURN_UNEXPECTED(update->GenerateRequirements(context)); + update->GenerateRequirements(context); } return context.Build(); } @@ -63,7 +129,7 @@ Result>> TableRequirements::ForUpd TableUpdateContext context(&base, false); context.AddRequirement(std::make_unique(base.table_uuid)); for (const auto& update : table_updates) { - ICEBERG_RETURN_UNEXPECTED(update->GenerateRequirements(context)); + update->GenerateRequirements(context); } return context.Build(); } diff --git a/src/iceberg/table_requirements.h 
b/src/iceberg/table_requirements.h index 7af2fb2df..f79f0bead 100644 --- a/src/iceberg/table_requirements.h +++ b/src/iceberg/table_requirements.h @@ -27,6 +27,8 @@ /// for optimistic concurrency control when committing table changes. #include +#include +#include #include #include "iceberg/iceberg_export.h" @@ -68,27 +70,24 @@ class ICEBERG_EXPORT TableUpdateContext { /// \brief Build and return the list of requirements Result>> Build(); - // Getters for deduplication flags - bool added_last_assigned_field_id() const { return added_last_assigned_field_id_; } - bool added_current_schema_id() const { return added_current_schema_id_; } - bool added_last_assigned_partition_id() const { - return added_last_assigned_partition_id_; - } - bool added_default_spec_id() const { return added_default_spec_id_; } - bool added_default_sort_order_id() const { return added_default_sort_order_id_; } - - // Setters for deduplication flags - void set_added_last_assigned_field_id(bool value) { - added_last_assigned_field_id_ = value; - } - void set_added_current_schema_id(bool value) { added_current_schema_id_ = value; } - void set_added_last_assigned_partition_id(bool value) { - added_last_assigned_partition_id_ = value; - } - void set_added_default_spec_id(bool value) { added_default_spec_id_ = value; } - void set_added_default_sort_order_id(bool value) { - added_default_sort_order_id_ = value; - } + // Helper methods to deduplicate requirements to add. 
+ /// \brief Require that the last assigned field ID remains unchanged + void RequireLastAssignedFieldIdUnchanged(); + /// \brief Require that the current schema ID remains unchanged + void RequireCurrentSchemaIdUnchanged(); + /// \brief Require that the last assigned partition ID remains unchanged + void RequireLastAssignedPartitionIdUnchanged(); + /// \brief Require that the default spec ID remains unchanged + void RequireDefaultSpecIdUnchanged(); + /// \brief Require that the default sort order ID remains unchanged + void RequireDefaultSortOrderIdUnchanged(); + /// \brief Require that no branches have been changed + void RequireNoBranchesChanged(); + + /// \brief Track a changed ref and return whether it was newly added + /// \param ref_name The name of the ref being changed + /// \return true if this is the first time the ref is being changed + bool AddChangedRef(const std::string& ref_name); private: const TableMetadata* base_; @@ -102,6 +101,9 @@ class ICEBERG_EXPORT TableUpdateContext { bool added_last_assigned_partition_id_ = false; bool added_default_spec_id_ = false; bool added_default_sort_order_id_ = false; + + // Track refs that have been changed to avoid duplicate requirements + std::unordered_set changed_refs_; }; /// \brief Factory class for generating table requirements diff --git a/src/iceberg/table_update.cc b/src/iceberg/table_update.cc index 6c1ad72e4..90f7de622 100644 --- a/src/iceberg/table_update.cc +++ b/src/iceberg/table_update.cc @@ -21,7 +21,6 @@ #include "iceberg/exception.h" #include "iceberg/table_metadata.h" -#include "iceberg/table_requirement.h" #include "iceberg/table_requirements.h" namespace iceberg::table { @@ -32,9 +31,8 @@ void AssignUUID::ApplyTo(TableMetadataBuilder& builder) const { builder.AssignUUID(uuid_); } -Status AssignUUID::GenerateRequirements(TableUpdateContext& context) const { +void AssignUUID::GenerateRequirements(TableUpdateContext& context) const { // AssignUUID does not generate additional requirements. 
- return {}; } // UpgradeFormatVersion @@ -43,8 +41,8 @@ void UpgradeFormatVersion::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status UpgradeFormatVersion::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("UpgradeFormatVersion::GenerateRequirements not implemented"); +void UpgradeFormatVersion::GenerateRequirements(TableUpdateContext& context) const { + // UpgradeFormatVersion doesn't generate any requirements } // AddSchema @@ -53,8 +51,8 @@ void AddSchema::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status AddSchema::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("AddTableSchema::GenerateRequirements not implemented"); +void AddSchema::GenerateRequirements(TableUpdateContext& context) const { + context.RequireLastAssignedFieldIdUnchanged(); } // SetCurrentSchema @@ -63,8 +61,8 @@ void SetCurrentSchema::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status SetCurrentSchema::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("SetCurrentTableSchema::GenerateRequirements not implemented"); +void SetCurrentSchema::GenerateRequirements(TableUpdateContext& context) const { + context.RequireCurrentSchemaIdUnchanged(); } // AddPartitionSpec @@ -73,8 +71,8 @@ void AddPartitionSpec::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status AddPartitionSpec::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("AddTablePartitionSpec::GenerateRequirements not implemented"); +void AddPartitionSpec::GenerateRequirements(TableUpdateContext& context) const { + context.RequireLastAssignedPartitionIdUnchanged(); } // SetDefaultPartitionSpec @@ -83,9 +81,8 @@ void 
SetDefaultPartitionSpec::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status SetDefaultPartitionSpec::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented( - "SetDefaultTablePartitionSpec::GenerateRequirements not implemented"); +void SetDefaultPartitionSpec::GenerateRequirements(TableUpdateContext& context) const { + context.RequireDefaultSpecIdUnchanged(); } // RemovePartitionSpecs @@ -94,9 +91,9 @@ void RemovePartitionSpecs::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status RemovePartitionSpecs::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented( - "RemoveTablePartitionSpecs::GenerateRequirements not implemented"); +void RemovePartitionSpecs::GenerateRequirements(TableUpdateContext& context) const { + context.RequireDefaultSpecIdUnchanged(); + context.RequireNoBranchesChanged(); } // RemoveSchemas @@ -105,28 +102,29 @@ void RemoveSchemas::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status RemoveSchemas::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("RemoveTableSchemas::GenerateRequirements not implemented"); +void RemoveSchemas::GenerateRequirements(TableUpdateContext& context) const { + context.RequireCurrentSchemaIdUnchanged(); + context.RequireNoBranchesChanged(); } // AddSortOrder void AddSortOrder::ApplyTo(TableMetadataBuilder& builder) const { - throw IcebergError(std::format("{} not implemented", __FUNCTION__)); + builder.AddSortOrder(sort_order_); } -Status AddSortOrder::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("AddTableSortOrder::GenerateRequirements not implemented"); +void AddSortOrder::GenerateRequirements(TableUpdateContext& context) const { + // AddSortOrder doesn't generate any requirements } 
// SetDefaultSortOrder void SetDefaultSortOrder::ApplyTo(TableMetadataBuilder& builder) const { - throw IcebergError(std::format("{} not implemented", __FUNCTION__)); + builder.SetDefaultSortOrder(sort_order_id_); } -Status SetDefaultSortOrder::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("SetDefaultTableSortOrder::GenerateRequirements not implemented"); +void SetDefaultSortOrder::GenerateRequirements(TableUpdateContext& context) const { + context.RequireDefaultSortOrderIdUnchanged(); } // AddSnapshot @@ -135,16 +133,16 @@ void AddSnapshot::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status AddSnapshot::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("AddTableSnapshot::GenerateRequirements not implemented"); +void AddSnapshot::GenerateRequirements(TableUpdateContext& context) const { + // AddSnapshot doesn't generate any requirements } // RemoveSnapshots void RemoveSnapshots::ApplyTo(TableMetadataBuilder& builder) const {} -Status RemoveSnapshots::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("RemoveTableSnapshots::GenerateRequirements not implemented"); +void RemoveSnapshots::GenerateRequirements(TableUpdateContext& context) const { + // RemoveSnapshots doesn't generate any requirements } // RemoveSnapshotRef @@ -153,8 +151,8 @@ void RemoveSnapshotRef::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status RemoveSnapshotRef::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("RemoveTableSnapshotRef::GenerateRequirements not implemented"); +void RemoveSnapshotRef::GenerateRequirements(TableUpdateContext& context) const { + // RemoveSnapshotRef doesn't generate any requirements } // SetSnapshotRef @@ -163,8 +161,17 @@ void SetSnapshotRef::ApplyTo(TableMetadataBuilder& builder) const { 
throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status SetSnapshotRef::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("SetTableSnapshotRef::GenerateRequirements not implemented"); +void SetSnapshotRef::GenerateRequirements(TableUpdateContext& context) const { + bool added = context.AddChangedRef(ref_name_); + if (added && context.base() != nullptr && !context.is_replace()) { + const auto& refs = context.base()->refs; + auto it = refs.find(ref_name_); + // Require that the ref does not exist (nullopt) or is the same as the base snapshot + std::optional base_snapshot_id = + (it != refs.end()) ? std::make_optional(it->second->snapshot_id) : std::nullopt; + context.AddRequirement( + std::make_unique(ref_name_, base_snapshot_id)); + } } // SetProperties @@ -173,9 +180,8 @@ void SetProperties::ApplyTo(TableMetadataBuilder& builder) const { builder.SetProperties(updated_); } -Status SetProperties::GenerateRequirements(TableUpdateContext& context) const { - // No requirements - return {}; +void SetProperties::GenerateRequirements(TableUpdateContext& context) const { + // SetProperties doesn't generate any requirements } // RemoveProperties @@ -184,9 +190,8 @@ void RemoveProperties::ApplyTo(TableMetadataBuilder& builder) const { builder.RemoveProperties(removed_); } -Status RemoveProperties::GenerateRequirements(TableUpdateContext& context) const { - // No requirements - return {}; +void RemoveProperties::GenerateRequirements(TableUpdateContext& context) const { + // RemoveProperties doesn't generate any requirements } // SetLocation @@ -195,8 +200,8 @@ void SetLocation::ApplyTo(TableMetadataBuilder& builder) const { throw IcebergError(std::format("{} not implemented", __FUNCTION__)); } -Status SetLocation::GenerateRequirements(TableUpdateContext& context) const { - return NotImplemented("SetTableLocation::GenerateRequirements not implemented"); +void SetLocation::GenerateRequirements(TableUpdateContext& context) 
const { + // SetLocation doesn't generate any requirements } } // namespace iceberg::table diff --git a/src/iceberg/table_update.h b/src/iceberg/table_update.h index 445295a4d..93b48cf27 100644 --- a/src/iceberg/table_update.h +++ b/src/iceberg/table_update.h @@ -58,8 +58,7 @@ class ICEBERG_EXPORT TableUpdate { /// provides information about the base metadata and operation mode. /// /// \param context The context containing base metadata and operation state - /// \return Status indicating success or failure with error details - virtual Status GenerateRequirements(TableUpdateContext& context) const = 0; + virtual void GenerateRequirements(TableUpdateContext& context) const = 0; }; namespace table { @@ -73,7 +72,7 @@ class ICEBERG_EXPORT AssignUUID : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::string uuid_; @@ -89,7 +88,7 @@ class ICEBERG_EXPORT UpgradeFormatVersion : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: int8_t format_version_; @@ -107,7 +106,7 @@ class ICEBERG_EXPORT AddSchema : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::shared_ptr schema_; @@ -123,7 +122,7 @@ class ICEBERG_EXPORT SetCurrentSchema : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: int32_t schema_id_; @@ -139,7 +138,7 @@ class 
ICEBERG_EXPORT AddPartitionSpec : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::shared_ptr spec_; @@ -154,7 +153,7 @@ class ICEBERG_EXPORT SetDefaultPartitionSpec : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: int32_t spec_id_; @@ -170,7 +169,7 @@ class ICEBERG_EXPORT RemovePartitionSpecs : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::vector spec_ids_; @@ -186,7 +185,7 @@ class ICEBERG_EXPORT RemoveSchemas : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::vector schema_ids_; @@ -202,7 +201,7 @@ class ICEBERG_EXPORT AddSortOrder : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::shared_ptr sort_order_; @@ -217,7 +216,7 @@ class ICEBERG_EXPORT SetDefaultSortOrder : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: int32_t sort_order_id_; @@ -233,7 +232,7 @@ class ICEBERG_EXPORT AddSnapshot : public TableUpdate { void 
ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::shared_ptr snapshot_; @@ -249,7 +248,7 @@ class ICEBERG_EXPORT RemoveSnapshots : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::vector snapshot_ids_; @@ -264,7 +263,7 @@ class ICEBERG_EXPORT RemoveSnapshotRef : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::string ref_name_; @@ -297,7 +296,7 @@ class ICEBERG_EXPORT SetSnapshotRef : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::string ref_name_; @@ -318,7 +317,7 @@ class ICEBERG_EXPORT SetProperties : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::unordered_map updated_; @@ -334,7 +333,7 @@ class ICEBERG_EXPORT RemoveProperties : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::vector removed_; @@ -349,7 +348,7 @@ class ICEBERG_EXPORT SetLocation : public TableUpdate { void ApplyTo(TableMetadataBuilder& builder) const override; - Status 
GenerateRequirements(TableUpdateContext& context) const override; + void GenerateRequirements(TableUpdateContext& context) const override; private: std::string location_; diff --git a/src/iceberg/test/CMakeLists.txt b/src/iceberg/test/CMakeLists.txt index a13d1f82e..9892e3d4f 100644 --- a/src/iceberg/test/CMakeLists.txt +++ b/src/iceberg/test/CMakeLists.txt @@ -87,7 +87,8 @@ add_iceberg_test(expression_test literal_test.cc inclusive_metrics_evaluator_test.cc inclusive_metrics_evaluator_with_transform_test.cc - predicate_test.cc) + predicate_test.cc + strict_metrics_evaluator_test.cc) add_iceberg_test(json_serde_test SOURCES diff --git a/src/iceberg/test/aggregate_test.cc b/src/iceberg/test/aggregate_test.cc index 264e606f7..8ee206580 100644 --- a/src/iceberg/test/aggregate_test.cc +++ b/src/iceberg/test/aggregate_test.cc @@ -23,6 +23,7 @@ #include "iceberg/expression/binder.h" #include "iceberg/expression/expressions.h" +#include "iceberg/manifest/manifest_entry.h" #include "iceberg/row/struct_like.h" #include "iceberg/schema.h" #include "iceberg/test/matchers.h" @@ -236,4 +237,243 @@ TEST(AggregateTest, MultipleAggregatesInEvaluator) { EXPECT_EQ(std::get(results[4].value()), 4); // count_star } +TEST(AggregateTest, AggregatesFromDataFileMetrics) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count_bound = BindAggregate(schema, Expressions::Count("id")); + auto count_null_bound = BindAggregate(schema, Expressions::CountNull("id")); + auto count_star_bound = BindAggregate(schema, Expressions::CountStar()); + auto max_bound = BindAggregate(schema, Expressions::Max("value")); + auto min_bound = BindAggregate(schema, Expressions::Min("value")); + + std::vector> aggregates{ + count_bound, count_null_bound, count_star_bound, max_bound, min_bound}; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates)); + + ICEBERG_UNWRAP_OR_FAIL(auto lower, Literal::Int(5).Serialize()); 
+ ICEBERG_UNWRAP_OR_FAIL(auto upper, Literal::Int(50).Serialize()); + DataFile file{ + .record_count = 10, + .value_counts = {{1, 10}, {2, 10}}, + .null_value_counts = {{1, 2}, {2, 0}}, + .lower_bounds = {{2, lower}}, + .upper_bounds = {{2, upper}}, + }; + + ASSERT_TRUE(evaluator->Update(file).has_value()); + + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), aggregates.size()); + EXPECT_EQ(std::get(results[0].value()), 8); // count(id) = 10 - 2 + EXPECT_EQ(std::get(results[1].value()), 2); // count_null(id) + EXPECT_EQ(std::get(results[2].value()), 10); // count_star + EXPECT_EQ(std::get(results[3].value()), 50); // max(value) + EXPECT_EQ(std::get(results[4].value()), 5); // min(value) +} + +TEST(AggregateTest, AggregatesFromDataFileMissingMetricsReturnNull) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count_bound = BindAggregate(schema, Expressions::Count("id")); + auto count_null_bound = BindAggregate(schema, Expressions::CountNull("id")); + auto count_star_bound = BindAggregate(schema, Expressions::CountStar()); + auto max_bound = BindAggregate(schema, Expressions::Max("value")); + auto min_bound = BindAggregate(schema, Expressions::Min("value")); + + std::vector> aggregates{ + count_bound, count_null_bound, count_star_bound, max_bound, min_bound}; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates)); + + DataFile file{.record_count = -1}; // missing/invalid + + ASSERT_TRUE(evaluator->Update(file).has_value()); + + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), aggregates.size()); + for (const auto& literal : results) { + EXPECT_TRUE(literal.IsNull()); + } +} + +TEST(AggregateTest, AggregatesFromDataFileWithTransform) { + Schema schema({SchemaField::MakeOptional(1, "id", int32())}); + + auto truncate_id = Expressions::Truncate("id", 10); + auto max_bound = 
BindAggregate(schema, Expressions::Max(truncate_id)); + auto min_bound = BindAggregate(schema, Expressions::Min(truncate_id)); + + std::vector> aggregates{max_bound, min_bound}; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates)); + + ICEBERG_UNWRAP_OR_FAIL(auto lower, Literal::Int(5).Serialize()); + ICEBERG_UNWRAP_OR_FAIL(auto upper, Literal::Int(23).Serialize()); + DataFile file{ + .record_count = 5, + .value_counts = {{1, 5}}, + .null_value_counts = {{1, 0}}, + .lower_bounds = {{1, lower}}, + .upper_bounds = {{1, upper}}, + }; + + ASSERT_TRUE(evaluator->Update(file).has_value()); + + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), aggregates.size()); + // Truncate width 10: max(truncate(23)) -> 20, min(truncate(5)) -> 0 + EXPECT_EQ(std::get(results[0].value()), 20); + EXPECT_EQ(std::get(results[1].value()), 0); + EXPECT_TRUE(evaluator->AllAggregatorsValid()); +} + +TEST(AggregateTest, DataFileAggregatorParity) { + Schema schema({SchemaField::MakeRequired(1, "id", int32()), + SchemaField::MakeOptional(2, "no_stats", int32()), + SchemaField::MakeOptional(3, "all_nulls", string()), + SchemaField::MakeOptional(4, "some_nulls", string())}); + + auto make_bounds = [](int field_id, int32_t lower, int32_t upper) { + std::map> lower_bounds; + std::map> upper_bounds; + auto lser = Literal::Int(lower).Serialize().value(); + auto user = Literal::Int(upper).Serialize().value(); + lower_bounds.emplace(field_id, std::move(lser)); + upper_bounds.emplace(field_id, std::move(user)); + return std::pair{std::move(lower_bounds), std::move(upper_bounds)}; + }; + + auto [b1_lower, b1_upper] = make_bounds(1, 33, 2345); + DataFile file{ + .file_path = "file.avro", + .record_count = 50, + .value_counts = {{1, 50}, {3, 50}, {4, 50}}, + .null_value_counts = {{1, 10}, {3, 50}, {4, 10}}, + .lower_bounds = std::move(b1_lower), + .upper_bounds = std::move(b1_upper), + }; + + auto [b2_lower, b2_upper] = make_bounds(1, 33, 
100); + DataFile missing_some_nulls_1{ + .file_path = "file_2.avro", + .record_count = 20, + .value_counts = {{1, 20}, {3, 20}}, + .null_value_counts = {{1, 0}, {3, 20}}, + .lower_bounds = std::move(b2_lower), + .upper_bounds = std::move(b2_upper), + }; + + auto [b3_lower, b3_upper] = make_bounds(1, -33, 3333); + DataFile missing_some_nulls_2{ + .file_path = "file_3.avro", + .record_count = 20, + .value_counts = {{1, 20}, {3, 20}}, + .null_value_counts = {{1, 20}, {3, 20}}, + .lower_bounds = std::move(b3_lower), + .upper_bounds = std::move(b3_upper), + }; + + DataFile missing_some_stats{ + .file_path = "file_missing_stats.avro", + .record_count = 20, + .value_counts = {{1, 20}, {4, 10}}, + }; + auto [b4_lower, b4_upper] = make_bounds(1, -3, 1333); + missing_some_stats.lower_bounds = std::move(b4_lower); + missing_some_stats.upper_bounds = std::move(b4_upper); + + DataFile missing_all_optional_stats{ + .file_path = "file_null_stats.avro", + .record_count = 20, + }; + + auto run_case = [&](const std::vector>& exprs, + const std::vector& files, + const std::vector>& expected, + bool expect_all_valid) { + std::vector> aggregates; + aggregates.reserve(exprs.size()); + for (const auto& e : exprs) { + aggregates.emplace_back(BindAggregate(schema, e)); + } + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates)); + for (const auto& f : files) { + ASSERT_TRUE(evaluator->Update(f).has_value()); + } + EXPECT_EQ(evaluator->AllAggregatorsValid(), expect_all_valid); + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), expected.size()); + for (size_t i = 0; i < expected.size(); ++i) { + if (!expected[i].has_value()) { + EXPECT_TRUE(results[i].IsNull()); + } else { + const auto& exp = *expected[i]; + const auto& res = results[i].value(); + if (std::holds_alternative(exp)) { + EXPECT_EQ(std::get(res), std::get(exp)); + } else if (std::holds_alternative(exp)) { + EXPECT_EQ(std::get(res), std::get(exp)); + } else { + 
FAIL() << "Unexpected expected type"; + } + } + } + }; + + // testIntAggregate + run_case({Expressions::CountStar(), Expressions::Count("id"), + Expressions::CountNull("id"), Expressions::Max("id"), Expressions::Min("id")}, + {file, missing_some_nulls_1, missing_some_nulls_2}, + {Scalar{int64_t{90}}, Scalar{int64_t{60}}, Scalar{int64_t{30}}, + Scalar{int32_t{3333}}, Scalar{int32_t{-33}}}, + /*expect_all_valid=*/true); + + // testAllNulls + run_case({Expressions::CountStar(), Expressions::Count("all_nulls"), + Expressions::CountNull("all_nulls"), Expressions::Max("all_nulls"), + Expressions::Min("all_nulls")}, + {file, missing_some_nulls_1, missing_some_nulls_2}, + {Scalar{int64_t{90}}, Scalar{int64_t{0}}, Scalar{int64_t{90}}, std::nullopt, + std::nullopt}, + /*expect_all_valid=*/true); + + // testSomeNulls -> missing null counts for field 4 + run_case({Expressions::CountStar(), Expressions::Count("some_nulls"), + Expressions::CountNull("some_nulls"), Expressions::Max("some_nulls"), + Expressions::Min("some_nulls")}, + {file, missing_some_nulls_1, missing_some_nulls_2}, + {Scalar{int64_t{90}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); + + // testNoStats -> field 2 has no metrics + run_case({Expressions::CountStar(), Expressions::Count("no_stats"), + Expressions::CountNull("no_stats"), Expressions::Max("no_stats"), + Expressions::Min("no_stats")}, + {file, missing_some_nulls_1, missing_some_nulls_2}, + {Scalar{int64_t{90}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); + + // testIntAggregateAllMissingStats -> id missing optional stats + run_case({Expressions::CountStar(), Expressions::Count("id"), + Expressions::CountNull("id"), Expressions::Max("id"), Expressions::Min("id")}, + {missing_all_optional_stats}, + {Scalar{int64_t{20}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); + + // testOptionalColAllMissingStats -> field 2 missing 
everything + run_case({Expressions::CountStar(), Expressions::Count("no_stats"), + Expressions::CountNull("no_stats"), Expressions::Max("no_stats"), + Expressions::Min("no_stats")}, + {missing_all_optional_stats}, + {Scalar{int64_t{20}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); + + // testMissingSomeStats -> some_nulls missing null stats entirely + run_case({Expressions::CountStar(), Expressions::Count("some_nulls"), + Expressions::Max("some_nulls"), Expressions::Min("some_nulls")}, + {missing_some_stats}, + {Scalar{int64_t{20}}, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); +} + } // namespace iceberg diff --git a/src/iceberg/test/meson.build b/src/iceberg/test/meson.build index 4cb153ba1..c73abe188 100644 --- a/src/iceberg/test/meson.build +++ b/src/iceberg/test/meson.build @@ -65,6 +65,7 @@ iceberg_tests = { 'inclusive_metrics_evaluator_with_transform_test.cc', 'literal_test.cc', 'predicate_test.cc', + 'strict_metrics_evaluator_test.cc', ), }, 'json_serde_test': { diff --git a/src/iceberg/test/predicate_test.cc b/src/iceberg/test/predicate_test.cc index 532e908b4..fab0b5617 100644 --- a/src/iceberg/test/predicate_test.cc +++ b/src/iceberg/test/predicate_test.cc @@ -26,7 +26,6 @@ #include "iceberg/schema.h" #include "iceberg/test/matchers.h" #include "iceberg/type.h" -#include "iceberg/util/macros.h" namespace iceberg { @@ -607,24 +606,24 @@ std::shared_ptr AssertAndCastToBoundPredicate( } // namespace TEST_F(PredicateTest, BoundUnaryPredicateTestIsNull) { - ICEBERG_ASSIGN_OR_THROW(auto is_null_pred, Expressions::IsNull("name")->Bind( - *schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto is_null_pred, Expressions::IsNull("name")->Bind( + *schema_, /*case_sensitive=*/true)); auto bound_pred = AssertAndCastToBoundPredicate(is_null_pred); EXPECT_THAT(bound_pred->Test(Literal::Null(string())), HasValue(testing::Eq(true))); EXPECT_THAT(bound_pred->Test(Literal::String("test")), 
HasValue(testing::Eq(false))); } TEST_F(PredicateTest, BoundUnaryPredicateTestNotNull) { - ICEBERG_ASSIGN_OR_THROW(auto not_null_pred, Expressions::NotNull("name")->Bind( - *schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto not_null_pred, Expressions::NotNull("name")->Bind( + *schema_, /*case_sensitive=*/true)); auto bound_pred = AssertAndCastToBoundPredicate(not_null_pred); EXPECT_THAT(bound_pred->Test(Literal::String("test")), HasValue(testing::Eq(true))); EXPECT_THAT(bound_pred->Test(Literal::Null(string())), HasValue(testing::Eq(false))); } TEST_F(PredicateTest, BoundUnaryPredicateTestIsNaN) { - ICEBERG_ASSIGN_OR_THROW(auto is_nan_pred, Expressions::IsNaN("salary")->Bind( - *schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto is_nan_pred, Expressions::IsNaN("salary")->Bind( + *schema_, /*case_sensitive=*/true)); auto bound_pred = AssertAndCastToBoundPredicate(is_nan_pred); // Test with NaN values @@ -643,8 +642,8 @@ TEST_F(PredicateTest, BoundUnaryPredicateTestIsNaN) { } TEST_F(PredicateTest, BoundUnaryPredicateTestNotNaN) { - ICEBERG_ASSIGN_OR_THROW(auto not_nan_pred, Expressions::NotNaN("salary")->Bind( - *schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto not_nan_pred, Expressions::NotNaN("salary")->Bind( + *schema_, /*case_sensitive=*/true)); auto bound_pred = AssertAndCastToBoundPredicate(not_nan_pred); // Test with regular values @@ -661,34 +660,34 @@ TEST_F(PredicateTest, BoundUnaryPredicateTestNotNaN) { TEST_F(PredicateTest, BoundLiteralPredicateTestComparison) { // Test less than - ICEBERG_ASSIGN_OR_THROW(auto lt_pred, Expressions::LessThan("age", Literal::Int(30)) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto lt_pred, Expressions::LessThan("age", Literal::Int(30)) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_lt = AssertAndCastToBoundPredicate(lt_pred); EXPECT_THAT(bound_lt->Test(Literal::Int(20)), HasValue(testing::Eq(true))); 
EXPECT_THAT(bound_lt->Test(Literal::Int(30)), HasValue(testing::Eq(false))); EXPECT_THAT(bound_lt->Test(Literal::Int(40)), HasValue(testing::Eq(false))); // Test less than or equal - ICEBERG_ASSIGN_OR_THROW(auto lte_pred, - Expressions::LessThanOrEqual("age", Literal::Int(30)) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto lte_pred, + Expressions::LessThanOrEqual("age", Literal::Int(30)) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_lte = AssertAndCastToBoundPredicate(lte_pred); EXPECT_THAT(bound_lte->Test(Literal::Int(20)), HasValue(testing::Eq(true))); EXPECT_THAT(bound_lte->Test(Literal::Int(30)), HasValue(testing::Eq(true))); EXPECT_THAT(bound_lte->Test(Literal::Int(40)), HasValue(testing::Eq(false))); // Test greater than - ICEBERG_ASSIGN_OR_THROW(auto gt_pred, Expressions::GreaterThan("age", Literal::Int(30)) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto gt_pred, Expressions::GreaterThan("age", Literal::Int(30)) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_gt = AssertAndCastToBoundPredicate(gt_pred); EXPECT_THAT(bound_gt->Test(Literal::Int(20)), HasValue(testing::Eq(false))); EXPECT_THAT(bound_gt->Test(Literal::Int(30)), HasValue(testing::Eq(false))); EXPECT_THAT(bound_gt->Test(Literal::Int(40)), HasValue(testing::Eq(true))); // Test greater than or equal - ICEBERG_ASSIGN_OR_THROW(auto gte_pred, - Expressions::GreaterThanOrEqual("age", Literal::Int(30)) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto gte_pred, + Expressions::GreaterThanOrEqual("age", Literal::Int(30)) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_gte = AssertAndCastToBoundPredicate(gte_pred); EXPECT_THAT(bound_gte->Test(Literal::Int(20)), HasValue(testing::Eq(false))); EXPECT_THAT(bound_gte->Test(Literal::Int(30)), HasValue(testing::Eq(true))); @@ -697,16 +696,16 @@ TEST_F(PredicateTest, BoundLiteralPredicateTestComparison) { TEST_F(PredicateTest, 
BoundLiteralPredicateTestEquality) { // Test equal - ICEBERG_ASSIGN_OR_THROW(auto eq_pred, Expressions::Equal("age", Literal::Int(25)) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto eq_pred, Expressions::Equal("age", Literal::Int(25)) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_eq = AssertAndCastToBoundPredicate(eq_pred); EXPECT_THAT(bound_eq->Test(Literal::Int(25)), HasValue(testing::Eq(true))); EXPECT_THAT(bound_eq->Test(Literal::Int(26)), HasValue(testing::Eq(false))); EXPECT_THAT(bound_eq->Test(Literal::Int(24)), HasValue(testing::Eq(false))); // Test not equal - ICEBERG_ASSIGN_OR_THROW(auto neq_pred, Expressions::NotEqual("age", Literal::Int(25)) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto neq_pred, Expressions::NotEqual("age", Literal::Int(25)) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_neq = AssertAndCastToBoundPredicate(neq_pred); EXPECT_THAT(bound_neq->Test(Literal::Int(25)), HasValue(testing::Eq(false))); EXPECT_THAT(bound_neq->Test(Literal::Int(26)), HasValue(testing::Eq(true))); @@ -715,18 +714,18 @@ TEST_F(PredicateTest, BoundLiteralPredicateTestEquality) { TEST_F(PredicateTest, BoundLiteralPredicateTestWithDifferentTypes) { // Test with double - ICEBERG_ASSIGN_OR_THROW(auto gt_pred, - Expressions::GreaterThan("salary", Literal::Double(50000.0)) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto gt_pred, + Expressions::GreaterThan("salary", Literal::Double(50000.0)) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_double = AssertAndCastToBoundPredicate(gt_pred); EXPECT_THAT(bound_double->Test(Literal::Double(60000.0)), HasValue(testing::Eq(true))); EXPECT_THAT(bound_double->Test(Literal::Double(40000.0)), HasValue(testing::Eq(false))); EXPECT_THAT(bound_double->Test(Literal::Double(50000.0)), HasValue(testing::Eq(false))); // Test with string - ICEBERG_ASSIGN_OR_THROW(auto str_eq_pred, - Expressions::Equal("name", 
Literal::String("Alice")) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto str_eq_pred, + Expressions::Equal("name", Literal::String("Alice")) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_string = AssertAndCastToBoundPredicate(str_eq_pred); EXPECT_THAT(bound_string->Test(Literal::String("Alice")), HasValue(testing::Eq(true))); EXPECT_THAT(bound_string->Test(Literal::String("Bob")), HasValue(testing::Eq(false))); @@ -734,16 +733,16 @@ TEST_F(PredicateTest, BoundLiteralPredicateTestWithDifferentTypes) { HasValue(testing::Eq(false))); // Case sensitive // Test with boolean - ICEBERG_ASSIGN_OR_THROW(auto bool_eq_pred, - Expressions::Equal("active", Literal::Boolean(true)) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bool_eq_pred, + Expressions::Equal("active", Literal::Boolean(true)) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_bool = AssertAndCastToBoundPredicate(bool_eq_pred); EXPECT_THAT(bound_bool->Test(Literal::Boolean(true)), HasValue(testing::Eq(true))); EXPECT_THAT(bound_bool->Test(Literal::Boolean(false)), HasValue(testing::Eq(false))); } TEST_F(PredicateTest, BoundLiteralPredicateTestStartsWith) { - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto starts_with_pred, Expressions::StartsWith("name", "Jo")->Bind(*schema_, /*case_sensitive=*/true)); auto bound_pred = AssertAndCastToBoundPredicate(starts_with_pred); @@ -759,7 +758,7 @@ TEST_F(PredicateTest, BoundLiteralPredicateTestStartsWith) { EXPECT_THAT(bound_pred->Test(Literal::String("")), HasValue(testing::Eq(false))); // Test empty prefix - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto empty_prefix_pred, Expressions::StartsWith("name", "")->Bind(*schema_, /*case_sensitive=*/true)); auto bound_empty = AssertAndCastToBoundPredicate(empty_prefix_pred); @@ -770,7 +769,7 @@ TEST_F(PredicateTest, BoundLiteralPredicateTestStartsWith) { } TEST_F(PredicateTest, BoundLiteralPredicateTestNotStartsWith) { - 
ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto not_starts_with_pred, Expressions::NotStartsWith("name", "Jo")->Bind(*schema_, /*case_sensitive=*/true)); auto bound_pred = AssertAndCastToBoundPredicate(not_starts_with_pred); @@ -787,7 +786,7 @@ TEST_F(PredicateTest, BoundLiteralPredicateTestNotStartsWith) { } TEST_F(PredicateTest, BoundSetPredicateTestIn) { - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto in_pred, Expressions::In("age", {Literal::Int(10), Literal::Int(20), Literal::Int(30)}) ->Bind(*schema_, /*case_sensitive=*/true)); @@ -805,7 +804,7 @@ TEST_F(PredicateTest, BoundSetPredicateTestIn) { } TEST_F(PredicateTest, BoundSetPredicateTestNotIn) { - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto not_in_pred, Expressions::NotIn("age", {Literal::Int(10), Literal::Int(20), Literal::Int(30)}) ->Bind(*schema_, /*case_sensitive=*/true)); @@ -823,7 +822,7 @@ TEST_F(PredicateTest, BoundSetPredicateTestNotIn) { } TEST_F(PredicateTest, BoundSetPredicateTestWithStrings) { - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto in_pred, Expressions::In("name", {Literal::String("Alice"), Literal::String("Bob"), Literal::String("Charlie")}) @@ -843,10 +842,10 @@ TEST_F(PredicateTest, BoundSetPredicateTestWithStrings) { } TEST_F(PredicateTest, BoundSetPredicateTestWithLongs) { - ICEBERG_ASSIGN_OR_THROW(auto in_pred, - Expressions::In("id", {Literal::Long(100L), Literal::Long(200L), - Literal::Long(300L)}) - ->Bind(*schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto in_pred, + Expressions::In("id", {Literal::Long(100L), Literal::Long(200L), + Literal::Long(300L)}) + ->Bind(*schema_, /*case_sensitive=*/true)); auto bound_pred = AssertAndCastToBoundPredicate(in_pred); // Test longs in the set @@ -860,8 +859,8 @@ TEST_F(PredicateTest, BoundSetPredicateTestWithLongs) { } TEST_F(PredicateTest, BoundSetPredicateTestSingleLiteral) { - ICEBERG_ASSIGN_OR_THROW(auto in_pred, Expressions::In("age", {Literal::Int(42)}) - ->Bind(*schema_, 
/*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto in_pred, Expressions::In("age", {Literal::Int(42)}) + ->Bind(*schema_, /*case_sensitive=*/true)); // Single element IN becomes Equal EXPECT_EQ(in_pred->op(), Expression::Operation::kEq); diff --git a/src/iceberg/test/strict_metrics_evaluator_test.cc b/src/iceberg/test/strict_metrics_evaluator_test.cc new file mode 100644 index 000000000..fa6185c3b --- /dev/null +++ b/src/iceberg/test/strict_metrics_evaluator_test.cc @@ -0,0 +1,849 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "iceberg/expression/strict_metrics_evaluator.h" + +#include + +#include + +#include "iceberg/expression/binder.h" +#include "iceberg/expression/expressions.h" +#include "iceberg/manifest/manifest_entry.h" +#include "iceberg/schema.h" +#include "iceberg/test/matchers.h" +#include "iceberg/type.h" + +namespace iceberg { + +namespace { +constexpr bool kRowsMustMatch = true; +constexpr bool kRowsMightNotMatch = false; +} // namespace +using TestVariant = std::variant; + +class StrictMetricsEvaluatorTest : public ::testing::Test { + protected: + void SetUp() override { + schema_ = std::make_shared( + std::vector{ + SchemaField::MakeRequired(1, "id", int64()), + SchemaField::MakeOptional(2, "name", string()), + SchemaField::MakeRequired(3, "age", int32()), + SchemaField::MakeOptional(4, "salary", float64()), + SchemaField::MakeRequired(5, "active", boolean()), + SchemaField::MakeRequired(6, "date", string()), + }, + /*schema_id=*/0); + } + + Result> Bind(const std::shared_ptr& expr, + bool case_sensitive = true) { + return Binder::Bind(*schema_, expr, case_sensitive); + } + + std::shared_ptr PrepareDataFile( + const std::string& partition, int64_t record_count, int64_t file_size_in_bytes, + const std::map& lower_bounds, + const std::map& upper_bounds, + const std::map& value_counts = {}, + const std::map& null_counts = {}, + const std::map& nan_counts = {}) { + auto parse_bound = [&](const std::map& bounds, + std::map>& bound_values) { + for (const auto& [key, value] : bounds) { + if (key == "id") { + bound_values[1] = Literal::Long(std::get(value)).Serialize().value(); + } else if (key == "name") { + bound_values[2] = + Literal::String(std::get(value)).Serialize().value(); + } else if (key == "age") { + bound_values[3] = Literal::Int(std::get(value)).Serialize().value(); + } else if (key == "salary") { + bound_values[4] = Literal::Double(std::get(value)).Serialize().value(); + } else if (key == "active") { + bound_values[5] = 
Literal::Boolean(std::get(value)).Serialize().value(); + } + } + }; + + auto data_file = std::make_shared(); + data_file->file_path = "test_path"; + data_file->file_format = FileFormatType::kParquet; + data_file->partition.AddValue(Literal::String(partition)); + data_file->record_count = record_count; + data_file->file_size_in_bytes = file_size_in_bytes; + data_file->column_sizes = {}; + data_file->value_counts = value_counts; + data_file->null_value_counts = null_counts; + data_file->nan_value_counts = nan_counts; + data_file->split_offsets = {1}; + data_file->sort_order_id = 0; + parse_bound(upper_bounds, data_file->upper_bounds); + parse_bound(lower_bounds, data_file->lower_bounds); + return data_file; + } + + void TestCase(const std::shared_ptr& unbound, bool expected_result) { + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile(/*partition=*/"20251128", /*record_count=*/10, + /*file_size_in_bytes=*/1024, + /*lower_bounds=*/{{"id", static_cast(100)}}, + /*upper_bounds=*/{{"id", static_cast(200)}}, + /*value_counts=*/{{1, 10}}, /*null_counts=*/{{1, 0}}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), expected_result) << unbound->ToString(); + } + + void TestStringCase(const std::shared_ptr& unbound, bool expected_result) { + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile(/*partition=*/"20251128", /*record_count=*/10, + /*file_size_in_bytes=*/1024, + /*lower_bounds=*/{{"name", "123"}}, {{"name", "456"}}, + /*value_counts=*/{{2, 10}}, /*null_counts=*/{{2, 0}}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), expected_result) << unbound->ToString(); + } + + std::shared_ptr schema_; +}; + +TEST_F(StrictMetricsEvaluatorTest, CaseSensitiveTest) { + { + auto unbound = Expressions::Equal("id", 
Literal::Long(300)); + auto evaluator = StrictMetricsEvaluator::Make(unbound, schema_, true); + ASSERT_TRUE(evaluator.has_value()); + } + { + auto unbound = Expressions::Equal("ID", Literal::Long(300)); + auto evaluator = StrictMetricsEvaluator::Make(unbound, schema_, true); + ASSERT_FALSE(evaluator.has_value()); + ASSERT_EQ(evaluator.error().kind, ErrorKind::kInvalidExpression); + } + { + auto unbound = Expressions::Equal("ID", Literal::Long(300)); + auto evaluator = StrictMetricsEvaluator::Make(unbound, schema_, false); + ASSERT_TRUE(evaluator.has_value()); + } +} + +TEST_F(StrictMetricsEvaluatorTest, IsNullTest) { + { + auto unbound = Expressions::IsNull("name"); + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"name", "1"}}, {{"name", "2"}}, + {{2, 10}}, {{2, 5}}, {}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMightNotMatch) << unbound->ToString(); + } + { + auto unbound = Expressions::IsNull("name"); + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"name", "1"}}, {{"name", "2"}}, + {{2, 10}}, {{2, 10}}, {}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMustMatch) << unbound->ToString(); + } +} + +TEST_F(StrictMetricsEvaluatorTest, NotNullTest) { + { + auto unbound = Expressions::NotNull("name"); + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"name", "1"}}, {{"name", "2"}}, + {{2, 10}}, {{2, 5}}, {}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMightNotMatch) << unbound->ToString(); + } + { + auto unbound = Expressions::NotNull("name"); + 
ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"name", "1"}}, {{"name", "2"}}, + {{2, 10}}, {{2, 0}}, {}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMustMatch) << unbound->ToString(); + } +} + +TEST_F(StrictMetricsEvaluatorTest, IsNanTest) { + { + auto unbound = Expressions::IsNaN("salary"); + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"salary", 1.0}}, + {{"salary", 2.0}}, {{4, 10}}, {{4, 5}}, {{4, 5}}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMightNotMatch) << unbound->ToString(); + } + { + auto unbound = Expressions::IsNaN("salary"); + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"salary", 1.0}}, + {{"salary", 2.0}}, {{4, 10}}, {{4, 10}}, {{4, 5}}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMightNotMatch) << unbound->ToString(); + } + { + auto unbound = Expressions::IsNaN("salary"); + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"salary", 1.0}}, + {{"salary", 2.0}}, {{4, 10}}, {{4, 5}}, {{4, 10}}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMustMatch) << unbound->ToString(); + } +} + +TEST_F(StrictMetricsEvaluatorTest, NotNanTest) { + { + auto unbound = Expressions::NotNaN("salary"); + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"salary", 1.0}}, + {{"salary", 2.0}}, {{4, 
10}}, {}, {{4, 5}}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMightNotMatch) << unbound->ToString(); + } + { + auto unbound = Expressions::NotNaN("salary"); + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"salary", 1.0}}, + {{"salary", 2.0}}, {{4, 10}}, {}, {{4, 0}}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMustMatch) << unbound->ToString(); + } + { + auto unbound = Expressions::NotNaN("salary"); + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile("20251128", 10, 1024, {{"salary", 1.0}}, + {{"salary", 2.0}}, {{4, 10}}, {{4, 10}}, {}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), kRowsMustMatch) << unbound->ToString(); + } +} + +TEST_F(StrictMetricsEvaluatorTest, LTTest) { + TestCase(Expressions::LessThan("id", Literal::Long(300)), kRowsMustMatch); + TestCase(Expressions::LessThan("id", Literal::Long(150)), kRowsMightNotMatch); + TestCase(Expressions::LessThan("id", Literal::Long(100)), kRowsMightNotMatch); + TestCase(Expressions::LessThan("id", Literal::Long(200)), kRowsMightNotMatch); + TestCase(Expressions::LessThan("id", Literal::Long(99)), kRowsMightNotMatch); +} + +TEST_F(StrictMetricsEvaluatorTest, LTEQTest) { + TestCase(Expressions::LessThanOrEqual("id", Literal::Long(300)), kRowsMustMatch); + TestCase(Expressions::LessThanOrEqual("id", Literal::Long(150)), kRowsMightNotMatch); + TestCase(Expressions::LessThanOrEqual("id", Literal::Long(100)), kRowsMightNotMatch); + TestCase(Expressions::LessThanOrEqual("id", Literal::Long(200)), kRowsMustMatch); + TestCase(Expressions::LessThanOrEqual("id", Literal::Long(99)), kRowsMightNotMatch); +} + +TEST_F(StrictMetricsEvaluatorTest, GTTest) 
{ + TestCase(Expressions::GreaterThan("id", Literal::Long(300)), kRowsMightNotMatch); + TestCase(Expressions::GreaterThan("id", Literal::Long(150)), kRowsMightNotMatch); + TestCase(Expressions::GreaterThan("id", Literal::Long(100)), kRowsMightNotMatch); + TestCase(Expressions::GreaterThan("id", Literal::Long(200)), kRowsMightNotMatch); + TestCase(Expressions::GreaterThan("id", Literal::Long(99)), kRowsMustMatch); +} + +TEST_F(StrictMetricsEvaluatorTest, GTEQTest) { + TestCase(Expressions::GreaterThanOrEqual("id", Literal::Long(300)), kRowsMightNotMatch); + TestCase(Expressions::GreaterThanOrEqual("id", Literal::Long(150)), kRowsMightNotMatch); + TestCase(Expressions::GreaterThanOrEqual("id", Literal::Long(100)), kRowsMustMatch); + TestCase(Expressions::GreaterThanOrEqual("id", Literal::Long(200)), kRowsMightNotMatch); + TestCase(Expressions::GreaterThanOrEqual("id", Literal::Long(99)), kRowsMustMatch); +} + +TEST_F(StrictMetricsEvaluatorTest, EQTest) { + TestCase(Expressions::Equal("id", Literal::Long(300)), kRowsMightNotMatch); + TestCase(Expressions::Equal("id", Literal::Long(150)), kRowsMightNotMatch); + TestCase(Expressions::Equal("id", Literal::Long(100)), kRowsMightNotMatch); + TestCase(Expressions::Equal("id", Literal::Long(200)), kRowsMightNotMatch); + + auto test_case = [&](const std::shared_ptr& unbound, bool expected_result) { + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile(/*partition=*/"20251128", /*record_count=*/10, + /*file_size_in_bytes=*/1024, + /*lower_bounds=*/{{"id", static_cast(100)}}, + /*upper_bounds=*/{{"id", static_cast(100)}}, + /*value_counts=*/{{1, 10}}, /*null_counts=*/{{1, 0}}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), expected_result) << unbound->ToString(); + }; + test_case(Expressions::Equal("id", Literal::Long(100)), kRowsMustMatch); + test_case(Expressions::Equal("id", 
Literal::Long(200)), kRowsMightNotMatch); +} + +TEST_F(StrictMetricsEvaluatorTest, NotEqTest) { + TestCase(Expressions::NotEqual("id", Literal::Long(300)), kRowsMustMatch); + TestCase(Expressions::NotEqual("id", Literal::Long(150)), kRowsMightNotMatch); + TestCase(Expressions::NotEqual("id", Literal::Long(100)), kRowsMightNotMatch); + TestCase(Expressions::NotEqual("id", Literal::Long(200)), kRowsMightNotMatch); + TestCase(Expressions::NotEqual("id", Literal::Long(99)), kRowsMustMatch); +} + +TEST_F(StrictMetricsEvaluatorTest, InTest) { + TestCase(Expressions::In("id", + { + Literal::Long(100), + Literal::Long(200), + Literal::Long(300), + Literal::Long(400), + Literal::Long(500), + }), + kRowsMightNotMatch); + + auto test_case = [&](const std::shared_ptr& unbound, bool expected_result) { + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(unbound, schema_, true)); + auto file = PrepareDataFile(/*partition=*/"20251128", /*record_count=*/10, + /*file_size_in_bytes=*/1024, + /*lower_bounds=*/{{"id", static_cast(100)}}, + /*upper_bounds=*/{{"id", static_cast(100)}}, + /*value_counts=*/{{1, 10}}, /*null_counts=*/{{1, 0}}); + auto result = evaluator->Evaluate(*file); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value(), expected_result) << unbound->ToString(); + }; + test_case(Expressions::In("id", {Literal::Long(100), Literal::Long(200)}), + kRowsMustMatch); + test_case(Expressions::In("id", {Literal::Long(200), Literal::Long(300)}), + kRowsMightNotMatch); +} + +TEST_F(StrictMetricsEvaluatorTest, NotInTest) { + TestCase(Expressions::NotIn("id", + { + Literal::Long(88), + Literal::Long(99), + }), + kRowsMustMatch); + TestCase(Expressions::NotIn("id", + { + Literal::Long(288), + Literal::Long(299), + }), + kRowsMustMatch); + TestCase(Expressions::NotIn("id", + { + Literal::Long(88), + Literal::Long(288), + Literal::Long(299), + }), + kRowsMustMatch); + TestCase(Expressions::NotIn("id", + { + Literal::Long(88), + Literal::Long(100), + }), + 
kRowsMightNotMatch); + TestCase(Expressions::NotIn("id", + { + Literal::Long(88), + Literal::Long(101), + }), + kRowsMightNotMatch); + TestCase(Expressions::NotIn("id", + { + Literal::Long(100), + Literal::Long(101), + }), + kRowsMightNotMatch); +} + +TEST_F(StrictMetricsEvaluatorTest, StartsWithTest) { + // always true + TestStringCase(Expressions::StartsWith("name", "1"), kRowsMightNotMatch); +} + +TEST_F(StrictMetricsEvaluatorTest, NotStartsWithTest) { + TestStringCase(Expressions::NotStartsWith("name", "1"), kRowsMightNotMatch); +} + +class StrictMetricsEvaluatorMigratedTest : public StrictMetricsEvaluatorTest { + protected: + static constexpr int64_t kIntMinValue = 30; + static constexpr int64_t kIntMaxValue = 79; + static constexpr int64_t kAlwaysFive = 5; + + void SetUp() override { + schema_ = std::make_shared( + std::vector{ + SchemaField::MakeRequired(1, "id", int64()), + SchemaField::MakeOptional(2, "no_stats", int64()), + SchemaField::MakeRequired(3, "required", string()), + SchemaField::MakeOptional(4, "all_nulls", string()), + SchemaField::MakeOptional(5, "some_nulls", string()), + SchemaField::MakeOptional(6, "no_nulls", string()), + SchemaField::MakeRequired(7, "always_5", int64()), + SchemaField::MakeOptional(8, "all_nans", float64()), + SchemaField::MakeOptional(9, "some_nans", float32()), + SchemaField::MakeOptional(10, "no_nans", float32()), + SchemaField::MakeOptional(11, "all_nulls_double", float64()), + SchemaField::MakeOptional(12, "all_nans_v1_stats", float32()), + SchemaField::MakeOptional(13, "nan_and_null_only", float64()), + SchemaField::MakeOptional(14, "no_nan_stats", float64()), + SchemaField::MakeOptional( + 15, "struct", + std::make_shared(std::vector{ + SchemaField::MakeOptional(16, "nested_col_no_stats", int64()), + SchemaField::MakeOptional(17, "nested_col_with_stats", int64())})), + }, + /*schema_id=*/0); + + file_ = MakePrimaryFile(); + file_with_bounds_ = MakeSomeNullsFile(); + file_with_equal_bounds_ = 
MakeSomeNullsEqualBoundsFile(); + } + + std::shared_ptr MakePrimaryFile() { + auto data_file = std::make_shared(); + data_file->file_path = "file.avro"; + data_file->file_format = FileFormatType::kParquet; + data_file->record_count = 50; + data_file->value_counts = { + {4, 50L}, {5, 50L}, {6, 50L}, {8, 50L}, {9, 50L}, {10, 50L}, + {11, 50L}, {12, 50L}, {13, 50L}, {14, 50L}, {17, 50L}, + }; + data_file->null_value_counts = { + {4, 50L}, {5, 10L}, {6, 0L}, {11, 50L}, {12, 0L}, {13, 1L}, {17, 0L}, + }; + data_file->nan_value_counts = { + {8, 50L}, + {9, 10L}, + {10, 0L}, + }; + const float float_nan = std::numeric_limits::quiet_NaN(); + const double double_nan = std::numeric_limits::quiet_NaN(); + data_file->lower_bounds = { + {1, Literal::Long(kIntMinValue).Serialize().value()}, + {7, Literal::Long(kAlwaysFive).Serialize().value()}, + {12, Literal::Float(float_nan).Serialize().value()}, + {13, Literal::Double(double_nan).Serialize().value()}, + {17, Literal::Long(kIntMinValue).Serialize().value()}, + }; + data_file->upper_bounds = { + {1, Literal::Long(kIntMaxValue).Serialize().value()}, + {7, Literal::Long(kAlwaysFive).Serialize().value()}, + {12, Literal::Float(float_nan).Serialize().value()}, + {13, Literal::Double(double_nan).Serialize().value()}, + {17, Literal::Long(kIntMaxValue).Serialize().value()}, + }; + return data_file; + } + + std::shared_ptr MakeSomeNullsFile() { + auto data_file = std::make_shared(); + data_file->file_path = "file_2.avro"; + data_file->file_format = FileFormatType::kParquet; + data_file->record_count = 50; + data_file->value_counts = { + {4, 50L}, + {5, 50L}, + {6, 50L}, + {8, 50L}, + }; + data_file->null_value_counts = { + {4, 50L}, + {5, 10L}, + {6, 0L}, + }; + data_file->lower_bounds = { + {5, Literal::String("bbb").Serialize().value()}, + }; + data_file->upper_bounds = { + {5, Literal::String("eee").Serialize().value()}, + }; + return data_file; + } + + std::shared_ptr MakeSomeNullsEqualBoundsFile() { + auto data_file = 
std::make_shared(); + data_file->file_path = "file_3.avro"; + data_file->file_format = FileFormatType::kParquet; + data_file->record_count = 50; + data_file->value_counts = { + {4, 50L}, + {5, 50L}, + {6, 50L}, + }; + data_file->null_value_counts = { + {4, 50L}, + {5, 10L}, + {6, 0L}, + }; + data_file->lower_bounds = { + {5, Literal::String("bbb").Serialize().value()}, + }; + data_file->upper_bounds = { + {5, Literal::String("bbb").Serialize().value()}, + }; + return data_file; + } + + std::shared_ptr MakeMissingStatsFile() { + auto data_file = std::make_shared(); + data_file->file_path = "missing.parquet"; + data_file->file_format = FileFormatType::kParquet; + data_file->record_count = 50; + return data_file; + } + + std::shared_ptr MakeZeroRecordFile() { + auto data_file = std::make_shared(); + data_file->file_path = "zero.parquet"; + data_file->file_format = FileFormatType::kParquet; + data_file->record_count = 0; + return data_file; + } + + void ExpectShouldRead(const std::shared_ptr& expr, bool expected, + std::shared_ptr file = nullptr, + bool case_sensitive = true) { + auto target = file ? 
file : file_; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, + StrictMetricsEvaluator::Make(expr, schema_, case_sensitive)); + auto eval_result = evaluator->Evaluate(*target); + ASSERT_TRUE(eval_result.has_value()); + ASSERT_EQ(eval_result.value(), expected) << expr->ToString(); + } + + std::shared_ptr schema_; + std::shared_ptr file_; + std::shared_ptr file_with_bounds_; + std::shared_ptr file_with_equal_bounds_; +}; + +TEST_F(StrictMetricsEvaluatorMigratedTest, AllNulls) { + ExpectShouldRead(Expressions::NotNull("all_nulls"), false); + ExpectShouldRead(Expressions::NotNull("some_nulls"), false); + ExpectShouldRead(Expressions::NotNull("no_nulls"), true); + ExpectShouldRead(Expressions::NotEqual("all_nulls", Literal::String("a")), true); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, NoNulls) { + ExpectShouldRead(Expressions::IsNull("all_nulls"), true); + ExpectShouldRead(Expressions::IsNull("some_nulls"), false); + ExpectShouldRead(Expressions::IsNull("no_nulls"), false); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, SomeNulls) { + ExpectShouldRead(Expressions::LessThan("some_nulls", Literal::String("ggg")), false, + file_with_bounds_); + ExpectShouldRead(Expressions::LessThanOrEqual("some_nulls", Literal::String("eee")), + false, file_with_bounds_); + ExpectShouldRead(Expressions::GreaterThan("some_nulls", Literal::String("aaa")), false, + file_with_bounds_); + ExpectShouldRead(Expressions::GreaterThanOrEqual("some_nulls", Literal::String("bbb")), + false, file_with_bounds_); + ExpectShouldRead(Expressions::Equal("some_nulls", Literal::String("bbb")), false, + file_with_equal_bounds_); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, IsNaN) { + ExpectShouldRead(Expressions::IsNaN("all_nans"), true); + ExpectShouldRead(Expressions::IsNaN("some_nans"), false); + ExpectShouldRead(Expressions::IsNaN("no_nans"), false); + ExpectShouldRead(Expressions::IsNaN("all_nulls_double"), false); + ExpectShouldRead(Expressions::IsNaN("no_nan_stats"), false); + 
ExpectShouldRead(Expressions::IsNaN("all_nans_v1_stats"), false); + ExpectShouldRead(Expressions::IsNaN("nan_and_null_only"), false); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, NotNaN) { + ExpectShouldRead(Expressions::NotNaN("all_nans"), false); + ExpectShouldRead(Expressions::NotNaN("some_nans"), false); + ExpectShouldRead(Expressions::NotNaN("no_nans"), true); + ExpectShouldRead(Expressions::NotNaN("all_nulls_double"), true); + ExpectShouldRead(Expressions::NotNaN("no_nan_stats"), false); + ExpectShouldRead(Expressions::NotNaN("all_nans_v1_stats"), false); + ExpectShouldRead(Expressions::NotNaN("nan_and_null_only"), false); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, RequiredColumn) { + ExpectShouldRead(Expressions::NotNull("required"), true); + ExpectShouldRead(Expressions::IsNull("required"), false); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, MissingColumn) { + auto expr = Expressions::LessThan("missing", Literal::Long(5)); + auto evaluator = StrictMetricsEvaluator::Make(expr, schema_, true); + ASSERT_FALSE(evaluator.has_value()); + EXPECT_TRUE(evaluator.error().message.contains("Cannot find field 'missing'")) + << evaluator.error().message; +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, MissingStats) { + auto missing_stats = MakeMissingStatsFile(); + std::vector> expressions = { + Expressions::LessThan("no_stats", Literal::Long(5)), + Expressions::LessThanOrEqual("no_stats", Literal::Long(30)), + Expressions::Equal("no_stats", Literal::Long(70)), + Expressions::GreaterThan("no_stats", Literal::Long(78)), + Expressions::GreaterThanOrEqual("no_stats", Literal::Long(90)), + Expressions::NotEqual("no_stats", Literal::Long(101)), + Expressions::IsNull("no_stats"), + Expressions::NotNull("no_stats"), + Expressions::IsNaN("all_nans"), + Expressions::NotNaN("all_nans"), + }; + for (const auto& expr : expressions) { + ExpectShouldRead(expr, false, missing_stats); + } +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, ZeroRecordFile) { + auto 
zero_record_file = MakeZeroRecordFile(); + std::vector> expressions = { + Expressions::LessThan("id", Literal::Long(5)), + Expressions::LessThanOrEqual("id", Literal::Long(30)), + Expressions::Equal("id", Literal::Long(70)), + Expressions::GreaterThan("id", Literal::Long(78)), + Expressions::GreaterThanOrEqual("id", Literal::Long(90)), + Expressions::NotEqual("id", Literal::Long(101)), + Expressions::IsNull("some_nulls"), + Expressions::NotNull("some_nulls"), + Expressions::IsNaN("all_nans"), + Expressions::NotNaN("all_nans"), + }; + for (const auto& expr : expressions) { + ExpectShouldRead(expr, true, zero_record_file); + } +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, Not) { + ExpectShouldRead( + Expressions::Not(Expressions::LessThan("id", Literal::Long(kIntMinValue - 25))), + true); + ExpectShouldRead( + Expressions::Not(Expressions::GreaterThan("id", Literal::Long(kIntMinValue - 25))), + false); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, And) { + ExpectShouldRead( + Expressions::And(Expressions::GreaterThan("id", Literal::Long(kIntMinValue - 25)), + Expressions::LessThanOrEqual("id", Literal::Long(kIntMinValue))), + false); + ExpectShouldRead( + Expressions::And( + Expressions::LessThan("id", Literal::Long(kIntMinValue - 25)), + Expressions::GreaterThanOrEqual("id", Literal::Long(kIntMinValue - 30))), + false); + ExpectShouldRead( + Expressions::And( + Expressions::LessThan("id", Literal::Long(kIntMaxValue + 6)), + Expressions::GreaterThanOrEqual("id", Literal::Long(kIntMinValue - 30))), + true); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, Or) { + ExpectShouldRead( + Expressions::Or( + Expressions::LessThan("id", Literal::Long(kIntMinValue - 25)), + Expressions::GreaterThanOrEqual("id", Literal::Long(kIntMaxValue + 1))), + false); + ExpectShouldRead( + Expressions::Or( + Expressions::LessThan("id", Literal::Long(kIntMinValue - 25)), + Expressions::GreaterThanOrEqual("id", Literal::Long(kIntMaxValue - 19))), + false); + ExpectShouldRead( + 
Expressions::Or(Expressions::LessThan("id", Literal::Long(kIntMinValue - 25)), + Expressions::GreaterThanOrEqual("id", Literal::Long(kIntMinValue))), + true); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, IntegerLt) { + ExpectShouldRead(Expressions::LessThan("id", Literal::Long(kIntMinValue)), false); + ExpectShouldRead(Expressions::LessThan("id", Literal::Long(kIntMinValue + 1)), false); + ExpectShouldRead(Expressions::LessThan("id", Literal::Long(kIntMaxValue)), false); + ExpectShouldRead(Expressions::LessThan("id", Literal::Long(kIntMaxValue + 1)), true); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, IntegerLtEq) { + ExpectShouldRead(Expressions::LessThanOrEqual("id", Literal::Long(kIntMinValue - 1)), + false); + ExpectShouldRead(Expressions::LessThanOrEqual("id", Literal::Long(kIntMinValue)), + false); + ExpectShouldRead(Expressions::LessThanOrEqual("id", Literal::Long(kIntMaxValue)), true); + ExpectShouldRead(Expressions::LessThanOrEqual("id", Literal::Long(kIntMaxValue + 1)), + true); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, IntegerGt) { + ExpectShouldRead(Expressions::GreaterThan("id", Literal::Long(kIntMaxValue)), false); + ExpectShouldRead(Expressions::GreaterThan("id", Literal::Long(kIntMaxValue - 1)), + false); + ExpectShouldRead(Expressions::GreaterThan("id", Literal::Long(kIntMinValue)), false); + ExpectShouldRead(Expressions::GreaterThan("id", Literal::Long(kIntMinValue - 1)), true); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, IntegerGtEq) { + ExpectShouldRead(Expressions::GreaterThanOrEqual("id", Literal::Long(kIntMaxValue + 1)), + false); + ExpectShouldRead(Expressions::GreaterThanOrEqual("id", Literal::Long(kIntMaxValue)), + false); + ExpectShouldRead(Expressions::GreaterThanOrEqual("id", Literal::Long(kIntMinValue + 1)), + false); + ExpectShouldRead(Expressions::GreaterThanOrEqual("id", Literal::Long(kIntMinValue)), + true); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, IntegerEq) { + ExpectShouldRead(Expressions::Equal("id", 
Literal::Long(kIntMinValue - 25)), false); + ExpectShouldRead(Expressions::Equal("id", Literal::Long(kIntMinValue)), false); + ExpectShouldRead(Expressions::Equal("id", Literal::Long(kIntMaxValue - 4)), false); + ExpectShouldRead(Expressions::Equal("id", Literal::Long(kIntMaxValue)), false); + ExpectShouldRead(Expressions::Equal("id", Literal::Long(kIntMaxValue + 1)), false); + ExpectShouldRead(Expressions::Equal("always_5", Literal::Long(kIntMinValue - 25)), + true); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, IntegerNotEq) { + ExpectShouldRead(Expressions::NotEqual("id", Literal::Long(kIntMinValue - 25)), true); + ExpectShouldRead(Expressions::NotEqual("id", Literal::Long(kIntMinValue - 1)), true); + ExpectShouldRead(Expressions::NotEqual("id", Literal::Long(kIntMinValue)), false); + ExpectShouldRead(Expressions::NotEqual("id", Literal::Long(kIntMaxValue - 4)), false); + ExpectShouldRead(Expressions::NotEqual("id", Literal::Long(kIntMaxValue)), false); + ExpectShouldRead(Expressions::NotEqual("id", Literal::Long(kIntMaxValue + 1)), true); + ExpectShouldRead(Expressions::NotEqual("id", Literal::Long(kIntMaxValue + 6)), true); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, IntegerNotEqRewritten) { + ExpectShouldRead( + Expressions::Not(Expressions::Equal("id", Literal::Long(kIntMinValue - 25))), true); + ExpectShouldRead( + Expressions::Not(Expressions::Equal("id", Literal::Long(kIntMinValue - 1))), true); + ExpectShouldRead( + Expressions::Not(Expressions::Equal("id", Literal::Long(kIntMinValue))), false); + ExpectShouldRead( + Expressions::Not(Expressions::Equal("id", Literal::Long(kIntMaxValue - 4))), false); + ExpectShouldRead( + Expressions::Not(Expressions::Equal("id", Literal::Long(kIntMaxValue))), false); + ExpectShouldRead( + Expressions::Not(Expressions::Equal("id", Literal::Long(kIntMaxValue + 1))), true); + ExpectShouldRead( + Expressions::Not(Expressions::Equal("id", Literal::Long(kIntMaxValue + 6))), true); +} + 
+TEST_F(StrictMetricsEvaluatorMigratedTest, IntegerIn) { + ExpectShouldRead(Expressions::In("id", {Literal::Long(kIntMinValue - 25), + Literal::Long(kIntMinValue - 24)}), + false); + ExpectShouldRead(Expressions::In("id", {Literal::Long(kIntMinValue - 1), + Literal::Long(kIntMinValue)}), + false); + ExpectShouldRead(Expressions::In("id", {Literal::Long(kIntMaxValue - 4), + Literal::Long(kIntMaxValue - 3)}), + false); + ExpectShouldRead(Expressions::In("id", {Literal::Long(kIntMaxValue), + Literal::Long(kIntMaxValue + 1)}), + false); + ExpectShouldRead(Expressions::In("id", {Literal::Long(kIntMaxValue + 1), + Literal::Long(kIntMaxValue + 2)}), + false); + ExpectShouldRead(Expressions::In("always_5", {Literal::Long(5), Literal::Long(6)}), + true); + ExpectShouldRead( + Expressions::In("all_nulls", {Literal::String("abc"), Literal::String("def")}), + false); + ExpectShouldRead( + Expressions::In("some_nulls", {Literal::String("abc"), Literal::String("def")}), + false, file_with_equal_bounds_); + ExpectShouldRead( + Expressions::In("no_nulls", {Literal::String("abc"), Literal::String("def")}), + false); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, IntegerNotIn) { + ExpectShouldRead(Expressions::NotIn("id", {Literal::Long(kIntMinValue - 25), + Literal::Long(kIntMinValue - 24)}), + true); + ExpectShouldRead(Expressions::NotIn("id", {Literal::Long(kIntMinValue - 1), + Literal::Long(kIntMinValue)}), + false); + ExpectShouldRead(Expressions::NotIn("id", {Literal::Long(kIntMaxValue - 4), + Literal::Long(kIntMaxValue - 3)}), + false); + ExpectShouldRead(Expressions::NotIn("id", {Literal::Long(kIntMaxValue), + Literal::Long(kIntMaxValue + 1)}), + false); + ExpectShouldRead(Expressions::NotIn("id", {Literal::Long(kIntMaxValue + 1), + Literal::Long(kIntMaxValue + 2)}), + true); + ExpectShouldRead(Expressions::NotIn("always_5", {Literal::Long(5), Literal::Long(6)}), + false); + ExpectShouldRead( + Expressions::NotIn("all_nulls", {Literal::String("abc"), 
Literal::String("def")}), + true); + ExpectShouldRead( + Expressions::NotIn("some_nulls", {Literal::String("abc"), Literal::String("def")}), + true, file_with_equal_bounds_); + ExpectShouldRead( + Expressions::NotIn("no_nulls", {Literal::String("abc"), Literal::String("def")}), + false); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, EvaluateOnNestedColumnWithoutStats) { + ExpectShouldRead(Expressions::GreaterThanOrEqual("struct.nested_col_no_stats", + Literal::Long(kIntMinValue)), + false); + ExpectShouldRead(Expressions::LessThanOrEqual("struct.nested_col_no_stats", + Literal::Long(kIntMaxValue)), + false); + ExpectShouldRead(Expressions::IsNull("struct.nested_col_no_stats"), false); + ExpectShouldRead(Expressions::NotNull("struct.nested_col_no_stats"), false); +} + +TEST_F(StrictMetricsEvaluatorMigratedTest, EvaluateOnNestedColumnWithStats) { + ExpectShouldRead(Expressions::GreaterThanOrEqual("struct.nested_col_with_stats", + Literal::Long(kIntMinValue)), + false); + ExpectShouldRead(Expressions::LessThanOrEqual("struct.nested_col_with_stats", + Literal::Long(kIntMaxValue)), + false); + ExpectShouldRead(Expressions::IsNull("struct.nested_col_with_stats"), false); + ExpectShouldRead(Expressions::NotNull("struct.nested_col_with_stats"), false); +} + +} // namespace iceberg diff --git a/src/iceberg/test/table_metadata_builder_test.cc b/src/iceberg/test/table_metadata_builder_test.cc index ff41ae18c..a1e46615f 100644 --- a/src/iceberg/test/table_metadata_builder_test.cc +++ b/src/iceberg/test/table_metadata_builder_test.cc @@ -19,20 +19,33 @@ #include #include +#include #include #include "iceberg/partition_spec.h" +#include "iceberg/schema.h" #include "iceberg/snapshot.h" +#include "iceberg/sort_field.h" #include "iceberg/sort_order.h" #include "iceberg/table_metadata.h" #include "iceberg/table_update.h" #include "iceberg/test/matchers.h" +#include "iceberg/transform.h" +#include "iceberg/type.h" namespace iceberg { namespace { +// Helper function to create a 
simple schema for testing +std::shared_ptr CreateTestSchema() { + auto field1 = SchemaField::MakeRequired(1, "id", int32()); + auto field2 = SchemaField::MakeRequired(2, "data", string()); + auto field3 = SchemaField::MakeRequired(3, "ts", timestamp()); + return std::make_shared(std::vector{field1, field2, field3}, 0); +} + // Helper function to create base metadata for tests std::unique_ptr CreateBaseMetadata() { auto metadata = std::make_unique(); @@ -41,11 +54,14 @@ std::unique_ptr CreateBaseMetadata() { metadata->location = "s3://bucket/test"; metadata->last_sequence_number = 0; metadata->last_updated_ms = TimePointMs{std::chrono::milliseconds(1000)}; - metadata->last_column_id = 0; + metadata->last_column_id = 3; + metadata->current_schema_id = 0; + metadata->schemas.push_back(CreateTestSchema()); metadata->default_spec_id = PartitionSpec::kInitialSpecId; metadata->last_partition_id = 0; metadata->current_snapshot_id = Snapshot::kInvalidSnapshotId; metadata->default_sort_order_id = SortOrder::kInitialSortOrderId; + metadata->sort_orders.push_back(SortOrder::Unsorted()); metadata->next_row_id = TableMetadata::kInitialRowId; return metadata; } @@ -82,7 +98,7 @@ TEST(TableMetadataBuilderTest, BuildFromExisting) { EXPECT_EQ(metadata->location, "s3://bucket/test"); } -// Test AssignUUID method +// Test AssignUUID TEST(TableMetadataBuilderTest, AssignUUID) { // Assign UUID for new table auto builder = TableMetadataBuilder::BuildFromEmpty(2); @@ -174,17 +190,149 @@ TEST(TableMetadataBuilderTest, UpgradeFormatVersion) { EXPECT_THAT(builder->Build(), HasErrorMessage("Cannot downgrade")); } -// Test applying TableUpdate to builder -TEST(TableMetadataBuilderTest, ApplyUpdate) { - // Apply AssignUUID update - auto builder = TableMetadataBuilder::BuildFromEmpty(2); - table::AssignUUID update("apply-uuid"); - update.ApplyTo(*builder); - // TODO(Li Feiyang): Add more update and `apply` once other build methods are - // implemented +// Test AddSortOrder 
+TEST(TableMetadataBuilderTest, AddSortOrderBasic) { + auto base = CreateBaseMetadata(); + auto builder = TableMetadataBuilder::BuildFrom(base.get()); + auto schema = CreateTestSchema(); + // 1. Add unsorted - should reuse existing unsorted order + builder->AddSortOrder(SortOrder::Unsorted()); ICEBERG_UNWRAP_OR_FAIL(auto metadata, builder->Build()); - EXPECT_EQ(metadata->table_uuid, "apply-uuid"); + ASSERT_EQ(metadata->sort_orders.size(), 1); + EXPECT_TRUE(metadata->sort_orders[0]->is_unsorted()); + + // 2. Add basic sort order + builder = TableMetadataBuilder::BuildFrom(base.get()); + SortField field1(1, Transform::Identity(), SortDirection::kAscending, + NullOrder::kFirst); + ICEBERG_UNWRAP_OR_FAIL(auto order1, + SortOrder::Make(*schema, 1, std::vector{field1})); + builder->AddSortOrder(std::move(order1)); + ICEBERG_UNWRAP_OR_FAIL(metadata, builder->Build()); + ASSERT_EQ(metadata->sort_orders.size(), 2); + EXPECT_EQ(metadata->sort_orders[1]->order_id(), 1); + + // 3. Add duplicate - should be idempotent + builder = TableMetadataBuilder::BuildFrom(base.get()); + ICEBERG_UNWRAP_OR_FAIL(auto order2, + SortOrder::Make(*schema, 1, std::vector{field1})); + ICEBERG_UNWRAP_OR_FAIL(auto order3, + SortOrder::Make(*schema, 1, std::vector{field1})); + builder->AddSortOrder(std::move(order2)); + builder->AddSortOrder(std::move(order3)); // Duplicate + ICEBERG_UNWRAP_OR_FAIL(metadata, builder->Build()); + ASSERT_EQ(metadata->sort_orders.size(), 2); // Only one added + + // 4. 
Add multiple different orders + verify ID reassignment + builder = TableMetadataBuilder::BuildFrom(base.get()); + SortField field2(2, Transform::Identity(), SortDirection::kDescending, + NullOrder::kLast); + // User provides ID=99, Builder should reassign to ID=1 + ICEBERG_UNWRAP_OR_FAIL(auto order4, + SortOrder::Make(*schema, 99, std::vector{field1})); + ICEBERG_UNWRAP_OR_FAIL( + auto order5, SortOrder::Make(*schema, 2, std::vector{field1, field2})); + builder->AddSortOrder(std::move(order4)); + builder->AddSortOrder(std::move(order5)); + ICEBERG_UNWRAP_OR_FAIL(metadata, builder->Build()); + ASSERT_EQ(metadata->sort_orders.size(), 3); + EXPECT_EQ(metadata->sort_orders[1]->order_id(), 1); // Reassigned from 99 + EXPECT_EQ(metadata->sort_orders[2]->order_id(), 2); +} + +TEST(TableMetadataBuilderTest, AddSortOrderInvalid) { + auto base = CreateBaseMetadata(); + auto schema = CreateTestSchema(); + + // 1. Invalid field ID + auto builder = TableMetadataBuilder::BuildFrom(base.get()); + SortField invalid_field(999, Transform::Identity(), SortDirection::kAscending, + NullOrder::kFirst); + ICEBERG_UNWRAP_OR_FAIL(auto order1, + SortOrder::Make(1, std::vector{invalid_field})); + builder->AddSortOrder(std::move(order1)); + ASSERT_THAT(builder->Build(), IsError(ErrorKind::kValidationFailed)); + ASSERT_THAT(builder->Build(), HasErrorMessage("Cannot find source column")); + + // 2. Invalid transform (Day transform on string type) + builder = TableMetadataBuilder::BuildFrom(base.get()); + SortField invalid_transform(2, Transform::Day(), SortDirection::kAscending, + NullOrder::kFirst); + ICEBERG_UNWRAP_OR_FAIL(auto order2, + SortOrder::Make(1, std::vector{invalid_transform})); + builder->AddSortOrder(std::move(order2)); + ASSERT_THAT(builder->Build(), IsError(ErrorKind::kValidationFailed)); + ASSERT_THAT(builder->Build(), HasErrorMessage("Invalid source type")); + + // 3. 
Without schema + builder = TableMetadataBuilder::BuildFromEmpty(2); + builder->AssignUUID("test-uuid"); + SortField field(1, Transform::Identity(), SortDirection::kAscending, NullOrder::kFirst); + ICEBERG_UNWRAP_OR_FAIL(auto order3, + SortOrder::Make(*schema, 1, std::vector{field})); + builder->AddSortOrder(std::move(order3)); + ASSERT_THAT(builder->Build(), IsError(ErrorKind::kValidationFailed)); + ASSERT_THAT(builder->Build(), HasErrorMessage("Schema with ID")); +} + +// Test SetDefaultSortOrder +TEST(TableMetadataBuilderTest, SetDefaultSortOrderBasic) { + auto base = CreateBaseMetadata(); + auto schema = CreateTestSchema(); + + // 1. Set default sort order by SortOrder object + auto builder = TableMetadataBuilder::BuildFrom(base.get()); + SortField field1(1, Transform::Identity(), SortDirection::kAscending, + NullOrder::kFirst); + ICEBERG_UNWRAP_OR_FAIL(auto order1_unique, + SortOrder::Make(*schema, 1, std::vector{field1})); + auto order1 = std::shared_ptr(std::move(order1_unique)); + builder->SetDefaultSortOrder(order1); + ICEBERG_UNWRAP_OR_FAIL(auto metadata, builder->Build()); + ASSERT_EQ(metadata->sort_orders.size(), 2); + EXPECT_EQ(metadata->default_sort_order_id, 1); + EXPECT_EQ(metadata->sort_orders[1]->order_id(), 1); + + // 2. Set default sort order by order ID + builder = TableMetadataBuilder::BuildFrom(base.get()); + SortField field2(1, Transform::Identity(), SortDirection::kAscending, + NullOrder::kFirst); + ICEBERG_UNWRAP_OR_FAIL(auto order2_unique, + SortOrder::Make(*schema, 1, std::vector{field2})); + auto order2 = std::shared_ptr(std::move(order2_unique)); + builder->AddSortOrder(order2); + builder->SetDefaultSortOrder(1); + ICEBERG_UNWRAP_OR_FAIL(metadata, builder->Build()); + EXPECT_EQ(metadata->default_sort_order_id, 1); + + // 3. 
Set default sort order using -1 (last added) + builder = TableMetadataBuilder::BuildFrom(base.get()); + SortField field3(2, Transform::Identity(), SortDirection::kDescending, + NullOrder::kLast); + ICEBERG_UNWRAP_OR_FAIL(auto order3_unique, + SortOrder::Make(*schema, 1, std::vector{field3})); + auto order3 = std::shared_ptr(std::move(order3_unique)); + builder->AddSortOrder(order3); + builder->SetDefaultSortOrder(-1); // Use last added + ICEBERG_UNWRAP_OR_FAIL(metadata, builder->Build()); + EXPECT_EQ(metadata->default_sort_order_id, 1); + + // 4. Setting same order is no-op + builder = TableMetadataBuilder::BuildFrom(base.get()); + builder->SetDefaultSortOrder(0); + ICEBERG_UNWRAP_OR_FAIL(metadata, builder->Build()); + EXPECT_EQ(metadata->default_sort_order_id, 0); +} + +TEST(TableMetadataBuilderTest, SetDefaultSortOrderInvalid) { + auto base = CreateBaseMetadata(); + + // Try to use -1 (last added) when no order has been added + auto builder = TableMetadataBuilder::BuildFrom(base.get()); + builder->SetDefaultSortOrder(-1); + ASSERT_THAT(builder->Build(), IsError(ErrorKind::kValidationFailed)); + ASSERT_THAT(builder->Build(), HasErrorMessage("no sort order has been added")); } } // namespace iceberg diff --git a/src/iceberg/test/table_requirements_test.cc b/src/iceberg/test/table_requirements_test.cc index 441e72d41..041b44dd1 100644 --- a/src/iceberg/test/table_requirements_test.cc +++ b/src/iceberg/test/table_requirements_test.cc @@ -20,18 +20,21 @@ #include "iceberg/table_requirements.h" #include +#include #include #include #include #include "iceberg/partition_spec.h" +#include "iceberg/schema.h" #include "iceberg/snapshot.h" #include "iceberg/sort_order.h" #include "iceberg/table_metadata.h" #include "iceberg/table_requirement.h" #include "iceberg/table_update.h" #include "iceberg/test/matchers.h" +#include "iceberg/type.h" namespace iceberg { @@ -47,16 +50,43 @@ std::unique_ptr CreateBaseMetadata( metadata->last_sequence_number = 0; metadata->last_updated_ms = 
TimePointMs{std::chrono::milliseconds(1000)}; metadata->last_column_id = 0; + metadata->current_schema_id = Schema::kInitialSchemaId; metadata->default_spec_id = PartitionSpec::kInitialSpecId; metadata->last_partition_id = 0; metadata->current_snapshot_id = Snapshot::kInvalidSnapshotId; - metadata->default_sort_order_id = SortOrder::kInitialSortOrderId; + metadata->default_sort_order_id = SortOrder::kUnsortedOrderId; metadata->next_row_id = TableMetadata::kInitialRowId; return metadata; } +// Helper function to create a simple schema for tests +std::shared_ptr CreateTestSchema(int32_t schema_id = 0) { + std::vector fields; + fields.emplace_back(SchemaField::MakeRequired(1, "id", int32())); + return std::make_shared(std::move(fields), schema_id); +} + +// Helper function to count requirements of a specific type +template +int CountRequirementsOfType( + const std::vector>& requirements) { + return std::ranges::count_if(requirements, [](const auto& req) { + return dynamic_cast(req.get()) != nullptr; + }); +} + +// Helper function to add a branch to metadata +void AddBranch(TableMetadata& metadata, const std::string& name, int64_t snapshot_id) { + auto ref = std::make_shared(); + ref->snapshot_id = snapshot_id; + ref->retention = SnapshotRef::Branch{}; + metadata.refs[name] = ref; +} + } // namespace +// Empty Updates Tests + TEST(TableRequirementsTest, EmptyUpdatesForCreateTable) { std::vector> updates; @@ -104,6 +134,8 @@ TEST(TableRequirementsTest, EmptyUpdatesForReplaceTable) { EXPECT_EQ(assert_uuid->uuid(), metadata->table_uuid); } +// Table Existence Tests + TEST(TableRequirementsTest, TableAlreadyExists) { std::vector> updates; @@ -134,7 +166,8 @@ TEST(TableRequirementsTest, TableDoesNotExist) { EXPECT_THAT(status, IsOk()); } -// Test for AssignUUID update +// AssignUUID Tests + TEST(TableRequirementsTest, AssignUUID) { auto metadata = CreateBaseMetadata("original-uuid"); std::vector> updates; @@ -207,4 +240,875 @@ TEST(TableRequirementsTest, 
AssignUUIDForReplaceTable) { EXPECT_THAT(status, IsOk()); } +// UpgradeFormatVersion Tests + +TEST(TableRequirementsTest, UpgradeFormatVersion) { + auto metadata = CreateBaseMetadata(); + std::vector> updates; + updates.push_back(std::make_unique(2)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // UpgradeFormatVersion doesn't add additional requirements + ASSERT_EQ(requirements.size(), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +// AddSchema Tests + +TEST(TableRequirementsTest, AddSchema) { + auto metadata = CreateBaseMetadata(); + metadata->last_column_id = 1; + std::vector> updates; + + auto schema = CreateTestSchema(); + // Add multiple AddSchema updates + updates.push_back(std::make_unique(schema, 1)); + updates.push_back(std::make_unique(schema, 1)); + updates.push_back(std::make_unique(schema, 1)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have AssertUUID + AssertLastAssignedFieldId (deduplicated) + ASSERT_EQ(requirements.size(), 2); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Verify the last assigned field ID value + auto* assert_field_id = + dynamic_cast(requirements[1].get()); + ASSERT_NE(assert_field_id, nullptr); + EXPECT_EQ(assert_field_id->last_assigned_field_id(), 1); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, AddSchemaFailure) { + auto metadata = CreateBaseMetadata(); + metadata->last_column_id = 2; + + std::vector> updates; + auto schema = CreateTestSchema(); + 
updates.push_back(std::make_unique(schema, 2)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + + // Create updated metadata with different last_column_id + auto updated = CreateBaseMetadata(); + updated->last_column_id = 3; + + // Find and validate the AssertLastAssignedFieldId requirement + for (const auto& req : requirements) { + if (dynamic_cast(req.get()) != nullptr) { + auto status = req->Validate(updated.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("last assigned field ID does not match")); + break; + } + } +} + +// SetCurrentSchema Tests + +TEST(TableRequirementsTest, SetCurrentSchema) { + auto metadata = CreateBaseMetadata(); + metadata->current_schema_id = 3; + std::vector> updates; + + // Add multiple SetCurrentSchema updates + updates.push_back(std::make_unique(3)); + updates.push_back(std::make_unique(4)); + updates.push_back(std::make_unique(5)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have AssertUUID + AssertCurrentSchemaID (deduplicated) + ASSERT_EQ(requirements.size(), 2); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Verify the current schema ID value + auto* assert_schema_id = + dynamic_cast(requirements[1].get()); + ASSERT_NE(assert_schema_id, nullptr); + EXPECT_EQ(assert_schema_id->schema_id(), 3); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, SetCurrentSchemaFailure) { + auto metadata = CreateBaseMetadata(); + metadata->current_schema_id = 3; + + std::vector> updates; + updates.push_back(std::make_unique(3)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); 
+ ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + + // Create updated metadata with different current_schema_id + auto updated = CreateBaseMetadata(); + updated->current_schema_id = 4; + + // Find and validate the AssertCurrentSchemaID requirement + for (const auto& req : requirements) { + if (dynamic_cast(req.get()) != nullptr) { + auto status = req->Validate(updated.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("current schema ID does not match")); + break; + } + } +} + +// AddPartitionSpec Tests + +TEST(TableRequirementsTest, AddPartitionSpec) { + auto metadata = CreateBaseMetadata(); + metadata->last_partition_id = 3; + + std::vector> updates; + updates.push_back( + std::make_unique(PartitionSpec::Unpartitioned())); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have AssertUUID + AssertLastAssignedPartitionId + ASSERT_EQ(requirements.size(), 2); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), + 1); + + // Verify the last assigned partition ID value + auto* assert_partition_id = + dynamic_cast(requirements[1].get()); + ASSERT_NE(assert_partition_id, nullptr); + EXPECT_EQ(assert_partition_id->last_assigned_partition_id(), 3); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, AddPartitionSpecFailure) { + auto metadata = CreateBaseMetadata(); + metadata->last_partition_id = 3; + + std::vector> updates; + updates.push_back( + std::make_unique(PartitionSpec::Unpartitioned())); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + + // Create updated metadata with different last_partition_id + auto updated = 
CreateBaseMetadata(); + updated->last_partition_id = 4; + + // Find and validate the AssertLastAssignedPartitionId requirement + for (const auto& req : requirements) { + if (dynamic_cast(req.get()) != nullptr) { + auto status = req->Validate(updated.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("last assigned partition ID does not match")); + break; + } + } +} + +// SetDefaultPartitionSpec Tests + +TEST(TableRequirementsTest, SetDefaultPartitionSpec) { + auto metadata = CreateBaseMetadata(); + metadata->default_spec_id = 3; + + std::vector> updates; + // Add multiple SetDefaultPartitionSpec updates + updates.push_back(std::make_unique(3)); + updates.push_back(std::make_unique(4)); + updates.push_back(std::make_unique(5)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have AssertUUID + AssertDefaultSpecID (deduplicated) + ASSERT_EQ(requirements.size(), 2); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Verify the default spec ID value + auto* assert_spec_id = dynamic_cast(requirements[1].get()); + ASSERT_NE(assert_spec_id, nullptr); + EXPECT_EQ(assert_spec_id->spec_id(), 3); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, SetDefaultPartitionSpecFailure) { + auto metadata = CreateBaseMetadata(); + metadata->default_spec_id = PartitionSpec::kInitialSpecId; + + std::vector> updates; + updates.push_back( + std::make_unique(PartitionSpec::kInitialSpecId)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + + // Create updated metadata with different default_spec_id + auto updated = CreateBaseMetadata(); + 
updated->default_spec_id = PartitionSpec::kInitialSpecId + 1; + + // Find and validate the AssertDefaultSpecID requirement + for (const auto& req : requirements) { + if (dynamic_cast(req.get()) != nullptr) { + auto status = req->Validate(updated.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("default partition spec changed")); + break; + } + } +} + +// RemovePartitionSpecs Tests + +TEST(TableRequirementsTest, RemovePartitionSpecs) { + auto metadata = CreateBaseMetadata(); + metadata->default_spec_id = 3; + + std::vector> updates; + updates.push_back( + std::make_unique(std::vector{1, 2})); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have AssertUUID + AssertDefaultSpecID + ASSERT_EQ(requirements.size(), 2); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Verify the default spec ID value + auto* assert_spec_id = dynamic_cast(requirements[1].get()); + ASSERT_NE(assert_spec_id, nullptr); + EXPECT_EQ(assert_spec_id->spec_id(), 3); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, RemovePartitionSpecsWithBranch) { + auto metadata = CreateBaseMetadata(); + metadata->default_spec_id = 3; + AddBranch(*metadata, "branch", 42); + + std::vector> updates; + updates.push_back( + std::make_unique(std::vector{1, 2})); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have AssertUUID + AssertDefaultSpecID + AssertRefSnapshotID + ASSERT_EQ(requirements.size(), 3); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 
1); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, RemovePartitionSpecsWithSpecChangedFailure) { + auto metadata = CreateBaseMetadata(); + metadata->default_spec_id = 3; + + std::vector> updates; + updates.push_back( + std::make_unique(std::vector{1, 2})); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + + // Create updated metadata with different default_spec_id + auto updated = CreateBaseMetadata(); + updated->default_spec_id = 4; + + // Find and validate the AssertDefaultSpecID requirement + for (const auto& req : requirements) { + if (dynamic_cast(req.get()) != nullptr) { + auto status = req->Validate(updated.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("default partition spec changed")); + break; + } + } +} + +TEST(TableRequirementsTest, RemovePartitionSpecsWithBranchChangedFailure) { + auto metadata = CreateBaseMetadata(); + metadata->default_spec_id = 3; + AddBranch(*metadata, "test", 42); + + std::vector> updates; + updates.push_back( + std::make_unique(std::vector{1, 2})); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + + // Create updated metadata with changed branch + auto updated = CreateBaseMetadata(); + updated->default_spec_id = 3; + AddBranch(*updated, "test", 43); + + // Find and validate the AssertRefSnapshotID requirement + for (const auto& req : requirements) { + if (dynamic_cast(req.get()) != nullptr) { + auto status = req->Validate(updated.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("has changed")); + break; + } + } +} + +// RemoveSchemas Tests + +TEST(TableRequirementsTest, RemoveSchemas) { + auto metadata = 
CreateBaseMetadata(); + metadata->current_schema_id = 3; + + std::vector> updates; + updates.push_back(std::make_unique(std::vector{1, 2})); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have AssertUUID + AssertCurrentSchemaID + ASSERT_EQ(requirements.size(), 2); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Verify the current schema ID value + auto* assert_schema_id = + dynamic_cast(requirements[1].get()); + ASSERT_NE(assert_schema_id, nullptr); + EXPECT_EQ(assert_schema_id->schema_id(), 3); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, RemoveSchemasWithBranch) { + auto metadata = CreateBaseMetadata(); + metadata->current_schema_id = 3; + AddBranch(*metadata, "branch", 42); + + std::vector> updates; + updates.push_back(std::make_unique(std::vector{1, 2})); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have AssertUUID + AssertCurrentSchemaID + AssertRefSnapshotID + ASSERT_EQ(requirements.size(), 3); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, RemoveSchemasWithSchemaChangedFailure) { + auto metadata = CreateBaseMetadata(); + metadata->current_schema_id = 3; + + std::vector> updates; + updates.push_back(std::make_unique(std::vector{1, 2})); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = 
result.value(); + + // Create updated metadata with different current_schema_id + auto updated = CreateBaseMetadata(); + updated->current_schema_id = 4; + + // Find and validate the AssertCurrentSchemaID requirement + for (const auto& req : requirements) { + if (dynamic_cast(req.get()) != nullptr) { + auto status = req->Validate(updated.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("current schema ID does not match")); + break; + } + } +} + +TEST(TableRequirementsTest, RemoveSchemasWithBranchChangedFailure) { + auto metadata = CreateBaseMetadata(); + metadata->current_schema_id = 3; + AddBranch(*metadata, "test", 42); + + std::vector> updates; + updates.push_back(std::make_unique(std::vector{1, 2})); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + + // Create updated metadata with changed branch + auto updated = CreateBaseMetadata(); + updated->current_schema_id = 3; + AddBranch(*updated, "test", 43); + + // Find and validate the AssertRefSnapshotID requirement + for (const auto& req : requirements) { + if (dynamic_cast(req.get()) != nullptr) { + auto status = req->Validate(updated.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("has changed")); + break; + } + } +} + +// AddSortOrder Tests + +TEST(TableRequirementsTest, AddSortOrder) { + auto metadata = CreateBaseMetadata(); + std::vector> updates; + + updates.push_back(std::make_unique(SortOrder::Unsorted())); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // AddSortOrder doesn't add additional requirements + ASSERT_EQ(requirements.size(), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Validate against base metadata + for (const auto& req : requirements) { + 
EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +// SetDefaultSortOrder Tests + +TEST(TableRequirementsTest, SetDefaultSortOrder) { + auto metadata = CreateBaseMetadata(); + metadata->default_sort_order_id = 3; + + std::vector> updates; + // Add multiple SetDefaultSortOrder updates + updates.push_back(std::make_unique(3)); + updates.push_back(std::make_unique(4)); + updates.push_back(std::make_unique(5)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have AssertUUID + AssertDefaultSortOrderID (deduplicated) + ASSERT_EQ(requirements.size(), 2); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Verify the default sort order ID value + auto* assert_sort_order_id = + dynamic_cast(requirements[1].get()); + ASSERT_NE(assert_sort_order_id, nullptr); + EXPECT_EQ(assert_sort_order_id->sort_order_id(), 3); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, SetDefaultSortOrderFailure) { + auto metadata = CreateBaseMetadata(); + metadata->default_sort_order_id = SortOrder::kUnsortedOrderId; + + std::vector> updates; + updates.push_back( + std::make_unique(SortOrder::kUnsortedOrderId)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + + // Create updated metadata with different default_sort_order_id + auto updated = CreateBaseMetadata(); + updated->default_sort_order_id = SortOrder::kUnsortedOrderId + 1; + + // Find and validate the AssertDefaultSortOrderID requirement + for (const auto& req : requirements) { + if (dynamic_cast(req.get()) != nullptr) { + auto status = req->Validate(updated.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, 
HasErrorMessage("default sort order changed")); + break; + } + } +} + +// AddSnapshot Tests + +TEST(TableRequirementsTest, AddSnapshot) { + auto metadata = CreateBaseMetadata(); + + std::vector> updates; + auto snapshot = std::make_shared(); + snapshot->snapshot_id = 1; + snapshot->sequence_number = 1; + snapshot->timestamp_ms = TimePointMs{std::chrono::milliseconds(1000)}; + snapshot->manifest_list = "s3://bucket/manifest_list"; + updates.push_back(std::make_unique(snapshot)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + ASSERT_EQ(requirements.size(), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +// RemoveSnapshots Tests + +TEST(TableRequirementsTest, RemoveSnapshots) { + auto metadata = CreateBaseMetadata(); + + std::vector> updates; + updates.push_back(std::make_unique(std::vector{0})); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + ASSERT_EQ(requirements.size(), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +// SetSnapshotRef Tests + +TEST(TableRequirementsTest, SetSnapshotRef) { + constexpr int64_t kSnapshotId = 14; + const std::string kRefName = "branch"; + + auto metadata = CreateBaseMetadata(); + AddBranch(*metadata, kRefName, kSnapshotId); + + // Multiple updates to same ref should deduplicate + std::vector> updates; + updates.push_back(std::make_unique(kRefName, kSnapshotId, + SnapshotRefType::kBranch)); + updates.push_back(std::make_unique(kRefName, kSnapshotId + 1, + SnapshotRefType::kBranch)); + updates.push_back(std::make_unique(kRefName, kSnapshotId + 2, + SnapshotRefType::kBranch)); + + auto result = 
TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } + + ASSERT_EQ(requirements.size(), 2); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + auto* assert_ref = dynamic_cast(requirements[1].get()); + ASSERT_NE(assert_ref, nullptr); + EXPECT_EQ(assert_ref->snapshot_id(), kSnapshotId); + EXPECT_EQ(assert_ref->ref_name(), kRefName); +} + +// RemoveSnapshotRef Tests + +TEST(TableRequirementsTest, RemoveSnapshotRef) { + auto metadata = CreateBaseMetadata(); + + std::vector> updates; + updates.push_back(std::make_unique("branch")); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + ASSERT_EQ(requirements.size(), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +// SetAndRemoveProperties Tests + +TEST(TableRequirementsTest, SetProperties) { + auto metadata = CreateBaseMetadata(); + std::vector> updates; + + std::unordered_map props; + props["test"] = "value"; + updates.push_back(std::make_unique(props)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // SetProperties doesn't add additional requirements + ASSERT_EQ(requirements.size(), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +TEST(TableRequirementsTest, RemoveProperties) { + auto metadata = CreateBaseMetadata(); + std::vector> updates; + + updates.push_back( + std::make_unique(std::vector{"test"})); + + auto result = 
TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // RemoveProperties doesn't add additional requirements + ASSERT_EQ(requirements.size(), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +// SetLocation Tests + +TEST(TableRequirementsTest, SetLocation) { + auto metadata = CreateBaseMetadata(); + std::vector> updates; + + updates.push_back(std::make_unique("s3://new-bucket/test")); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // SetLocation doesn't add additional requirements + ASSERT_EQ(requirements.size(), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + + // Validate against base metadata + for (const auto& req : requirements) { + EXPECT_THAT(req->Validate(metadata.get()), IsOk()); + } +} + +// AssertRefSnapshotID Tests + +TEST(TableRequirementsTest, AssertRefSnapshotIDSuccess) { + auto metadata = CreateBaseMetadata(); + AddBranch(*metadata, "branch", 14); + + table::AssertRefSnapshotID requirement("branch", 14); + auto status = requirement.Validate(metadata.get()); + EXPECT_THAT(status, IsOk()); +} + +TEST(TableRequirementsTest, AssertRefSnapshotIDCreatedConcurrently) { + auto metadata = CreateBaseMetadata(); + AddBranch(*metadata, "random_branch", 14); + + // Requirement expects ref doesn't exist (nullopt snapshot_id) + table::AssertRefSnapshotID requirement("random_branch", std::nullopt); + auto status = requirement.Validate(metadata.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("was created concurrently")); +} + +TEST(TableRequirementsTest, AssertRefSnapshotIDMissing) { + auto metadata = CreateBaseMetadata(); + // No branch added + + // Requirement expects a 
snapshot ID that doesn't exist + table::AssertRefSnapshotID requirement("random_branch", 14); + auto status = requirement.Validate(metadata.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("is missing")); +} + +TEST(TableRequirementsTest, AssertRefSnapshotIDChanged) { + auto metadata = CreateBaseMetadata(); + AddBranch(*metadata, "random_branch", 15); + + // Requirement expects snapshot ID 14, but actual is 15 + table::AssertRefSnapshotID requirement("random_branch", 14); + auto status = requirement.Validate(metadata.get()); + EXPECT_THAT(status, IsError(ErrorKind::kCommitFailed)); + EXPECT_THAT(status, HasErrorMessage("has changed")); +} + +// Replace Table Tests (less restrictive than Update) + +TEST(TableRequirementsTest, ReplaceTableDoesNotRequireCurrentSchemaID) { + auto metadata = CreateBaseMetadata(); + metadata->current_schema_id = 3; + + std::vector> updates; + updates.push_back(std::make_unique(5)); + + auto result = TableRequirements::ForReplaceTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Replace table should NOT add AssertCurrentSchemaID + EXPECT_EQ(CountRequirementsOfType(requirements), 0); +} + +TEST(TableRequirementsTest, ReplaceTableDoesNotRequireDefaultSpecID) { + auto metadata = CreateBaseMetadata(); + metadata->default_spec_id = 3; + + std::vector> updates; + updates.push_back(std::make_unique(5)); + + auto result = TableRequirements::ForReplaceTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Replace table should NOT add AssertDefaultSpecID + EXPECT_EQ(CountRequirementsOfType(requirements), 0); +} + +TEST(TableRequirementsTest, ReplaceTableDoesNotRequireDefaultSortOrderID) { + auto metadata = CreateBaseMetadata(); + metadata->default_sort_order_id = 3; + + std::vector> updates; + updates.push_back(std::make_unique(5)); + + auto result = 
TableRequirements::ForReplaceTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Replace table should NOT add AssertDefaultSortOrderID + EXPECT_EQ(CountRequirementsOfType(requirements), 0); +} + +TEST(TableRequirementsTest, ReplaceTableDoesNotAddBranchRequirements) { + auto metadata = CreateBaseMetadata(); + metadata->current_schema_id = 3; + AddBranch(*metadata, "branch", 42); + + std::vector> updates; + updates.push_back(std::make_unique(std::vector{1, 2})); + + auto result = TableRequirements::ForReplaceTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Replace table should NOT add AssertRefSnapshotID for branches + EXPECT_EQ(CountRequirementsOfType(requirements), 0); +} + +// Combined Updates Tests + +TEST(TableRequirementsTest, MultipleUpdatesDeduplication) { + auto metadata = CreateBaseMetadata(); + metadata->last_column_id = 1; + metadata->current_schema_id = 0; + + std::vector> updates; + auto schema = CreateTestSchema(); + // Add multiple AddSchema updates - should only generate one requirement + updates.push_back(std::make_unique(schema, 1)); + updates.push_back(std::make_unique(schema, 1)); + // Add multiple SetCurrentSchema updates - should only generate one requirement + updates.push_back(std::make_unique(0)); + updates.push_back(std::make_unique(1)); + + auto result = TableRequirements::ForUpdateTable(*metadata, updates); + ASSERT_THAT(result, IsOk()); + + auto& requirements = result.value(); + // Should have: 1 AssertUUID + 1 AssertLastAssignedFieldId + 1 AssertCurrentSchemaID + ASSERT_EQ(requirements.size(), 3); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); + EXPECT_EQ(CountRequirementsOfType(requirements), 1); +} + } // namespace iceberg diff --git a/src/iceberg/test/table_update_test.cc b/src/iceberg/test/table_update_test.cc index 298141a93..afa70894c 100644 --- 
a/src/iceberg/test/table_update_test.cc +++ b/src/iceberg/test/table_update_test.cc @@ -19,6 +19,7 @@ #include "iceberg/table_update.h" +#include #include #include #include @@ -26,23 +27,34 @@ #include #include "iceberg/partition_spec.h" +#include "iceberg/schema.h" #include "iceberg/snapshot.h" +#include "iceberg/sort_field.h" #include "iceberg/sort_order.h" #include "iceberg/table_metadata.h" #include "iceberg/table_requirement.h" #include "iceberg/table_requirements.h" #include "iceberg/test/matchers.h" +#include "iceberg/transform.h" +#include "iceberg/type.h" namespace iceberg { namespace { +// Helper function to create a simple schema for testing +std::shared_ptr CreateTestSchema() { + auto field1 = SchemaField::MakeRequired(1, "id", int32()); + auto field2 = SchemaField::MakeRequired(2, "data", string()); + auto field3 = SchemaField::MakeRequired(3, "ts", timestamp()); + return std::make_shared(std::vector{field1, field2, field3}, 0); +} + // Helper function to generate requirements std::vector> GenerateRequirements( const TableUpdate& update, const TableMetadata* base) { TableUpdateContext context(base, /*is_replace=*/false); - EXPECT_THAT(update.GenerateRequirements(context), IsOk()); - + update.GenerateRequirements(context); auto requirements = context.Build(); EXPECT_THAT(requirements, IsOk()); return std::move(requirements.value()); @@ -56,35 +68,302 @@ std::unique_ptr CreateBaseMetadata() { metadata->location = "s3://bucket/test"; metadata->last_sequence_number = 0; metadata->last_updated_ms = TimePointMs{std::chrono::milliseconds(1000)}; - metadata->last_column_id = 0; + metadata->last_column_id = 3; + metadata->current_schema_id = 0; + metadata->schemas.push_back(CreateTestSchema()); metadata->default_spec_id = PartitionSpec::kInitialSpecId; metadata->last_partition_id = 0; metadata->current_snapshot_id = Snapshot::kInvalidSnapshotId; metadata->default_sort_order_id = SortOrder::kInitialSortOrderId; +
metadata->sort_orders.push_back(SortOrder::Unsorted()); metadata->next_row_id = TableMetadata::kInitialRowId; return metadata; } } // namespace -// Test GenerateRequirements for AssignUUID update -TEST(TableUpdateTest, AssignUUIDGenerateRequirements) { - table::AssignUUID update("new-uuid"); +// Parameter struct for testing GenerateRequirements behavior +struct GenerateRequirementsTestParam { + std::string test_name; + std::function()> update_factory; + // Expected number of requirements for existing table (new table always expects 0) + size_t expected_existing_table_count; + // Optional validator function to check the generated requirements + std::function>&, + const TableMetadata*)> + validator; +}; + +class GenerateRequirementsTest + : public ::testing::TestWithParam {}; - // New table - no requirements (AssignUUID doesn't generate requirements) - auto new_table_reqs = GenerateRequirements(update, nullptr); +TEST_P(GenerateRequirementsTest, GeneratesExpectedRequirements) { + const auto& param = GetParam(); + auto update = param.update_factory(); + + // New table - always no requirements + auto new_table_reqs = GenerateRequirements(*update, nullptr); EXPECT_TRUE(new_table_reqs.empty()); - // Existing table - AssignUUID doesn't generate requirements anymore - // The UUID assertion is added by ForUpdateTable/ForReplaceTable methods + // Existing table - check expected count + auto base = CreateBaseMetadata(); + auto existing_table_reqs = GenerateRequirements(*update, base.get()); + ASSERT_EQ(existing_table_reqs.size(), param.expected_existing_table_count); + + // Validate the requirements if validator is provided + if (param.validator) { + param.validator(existing_table_reqs, base.get()); + } +} + +INSTANTIATE_TEST_SUITE_P( + TableUpdateGenerateRequirements, GenerateRequirementsTest, + ::testing::Values( + // Updates that generate no requirements + GenerateRequirementsTestParam{ + .test_name = "AssignUUID", + .update_factory = + [] { return 
std::make_unique("new-uuid"); }, + .expected_existing_table_count = 0, + .validator = nullptr}, + GenerateRequirementsTestParam{ + .test_name = "UpgradeFormatVersion", + .update_factory = + [] { return std::make_unique(3); }, + .expected_existing_table_count = 0, + .validator = nullptr}, + GenerateRequirementsTestParam{ + .test_name = "AddSortOrder", + .update_factory = + [] { + auto schema = CreateTestSchema(); + SortField sort_field(1, Transform::Identity(), + SortDirection::kAscending, NullOrder::kFirst); + auto sort_order = + SortOrder::Make(*schema, 1, std::vector{sort_field}) + .value(); + return std::make_unique(std::move(sort_order)); + }, + .expected_existing_table_count = 0, + .validator = nullptr}, + GenerateRequirementsTestParam{.test_name = "AddSnapshot", + .update_factory = + [] { + auto snapshot = std::make_shared(); + return std::make_unique( + snapshot); + }, + .expected_existing_table_count = 0, + .validator = nullptr}, + GenerateRequirementsTestParam{ + .test_name = "RemoveSnapshotRef", + .update_factory = + [] { return std::make_unique("my-branch"); }, + .expected_existing_table_count = 0, + .validator = nullptr}, + GenerateRequirementsTestParam{ + .test_name = "SetProperties", + .update_factory = + [] { + return std::make_unique( + std::unordered_map{{"key", "value"}}); + }, + .expected_existing_table_count = 0, + .validator = nullptr}, + GenerateRequirementsTestParam{ + .test_name = "RemoveProperties", + .update_factory = + [] { + return std::make_unique( + std::vector{"key"}); + }, + .expected_existing_table_count = 0, + .validator = nullptr}, + GenerateRequirementsTestParam{ + .test_name = "SetLocation", + .update_factory = + [] { return std::make_unique("s3://new/location"); }, + .expected_existing_table_count = 0, + .validator = nullptr}, + + // Updates that generate single requirement for existing tables + GenerateRequirementsTestParam{ + .test_name = "AddSchema", + .update_factory = + [] { + auto new_schema = std::make_shared( + 
std::vector{ + SchemaField::MakeRequired(4, "new_col", string())}, + 3); + return std::make_unique(new_schema, 3); + }, + .expected_existing_table_count = 1, + .validator = + [](const std::vector>& reqs, + const TableMetadata* base) { + auto* assert_id = dynamic_cast( + reqs[0].get()); + ASSERT_NE(assert_id, nullptr); + EXPECT_EQ(assert_id->last_assigned_field_id(), base->last_column_id); + }}, + GenerateRequirementsTestParam{ + .test_name = "SetCurrentSchema", + .update_factory = [] { return std::make_unique(1); }, + .expected_existing_table_count = 1, + .validator = + [](const std::vector>& reqs, + const TableMetadata* base) { + auto* assert_id = + dynamic_cast(reqs[0].get()); + ASSERT_NE(assert_id, nullptr); + EXPECT_EQ(assert_id->schema_id(), base->current_schema_id); + }}, + GenerateRequirementsTestParam{ + .test_name = "AddPartitionSpec", + .update_factory = + [] { + PartitionField partition_field(1, 1, "id_identity", + Transform::Identity()); + auto spec = std::shared_ptr( + PartitionSpec::Make(1, {partition_field}).value().release()); + return std::make_unique(spec); + }, + .expected_existing_table_count = 1, + .validator = + [](const std::vector>& reqs, + const TableMetadata* base) { + auto* assert_id = + dynamic_cast( + reqs[0].get()); + ASSERT_NE(assert_id, nullptr); + EXPECT_EQ(assert_id->last_assigned_partition_id(), + base->last_partition_id); + }}, + GenerateRequirementsTestParam{ + .test_name = "SetDefaultPartitionSpec", + .update_factory = + [] { return std::make_unique(1); }, + .expected_existing_table_count = 1, + .validator = + [](const std::vector>& reqs, + const TableMetadata* base) { + auto* assert_id = + dynamic_cast(reqs[0].get()); + ASSERT_NE(assert_id, nullptr); + EXPECT_EQ(assert_id->spec_id(), base->default_spec_id); + }}, + GenerateRequirementsTestParam{ + .test_name = "SetDefaultSortOrder", + .update_factory = + [] { return std::make_unique(1); }, + .expected_existing_table_count = 1, + .validator = + [](const std::vector>& reqs, + 
const TableMetadata* base) { + auto* assert_sort_order = + dynamic_cast(reqs[0].get()); + ASSERT_NE(assert_sort_order, nullptr); + EXPECT_EQ(assert_sort_order->sort_order_id(), + base->default_sort_order_id); + }}, + GenerateRequirementsTestParam{ + .test_name = "RemovePartitionSpecs", + .update_factory = + [] { + return std::make_unique( + std::vector{1}); + }, + .expected_existing_table_count = 1, + .validator = + [](const std::vector>& reqs, + const TableMetadata* base) { + auto* assert_id = + dynamic_cast(reqs[0].get()); + ASSERT_NE(assert_id, nullptr); + EXPECT_EQ(assert_id->spec_id(), base->default_spec_id); + }}, + GenerateRequirementsTestParam{ + .test_name = "RemoveSchemas", + .update_factory = + [] { + return std::make_unique(std::vector{1}); + }, + .expected_existing_table_count = 1, + .validator = + [](const std::vector>& reqs, + const TableMetadata* base) { + auto* assert_id = + dynamic_cast(reqs[0].get()); + ASSERT_NE(assert_id, nullptr); + EXPECT_EQ(assert_id->schema_id(), base->current_schema_id); + }}), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + +// Test AssignUUID ApplyTo +TEST(TableUpdateTest, AssignUUIDApplyUpdate) { auto base = CreateBaseMetadata(); - auto existing_table_reqs = GenerateRequirements(update, base.get()); - EXPECT_TRUE(existing_table_reqs.empty()); + auto builder = TableMetadataBuilder::BuildFrom(base.get()); + + // Apply AssignUUID update + table::AssignUUID uuid_update("apply-uuid"); + uuid_update.ApplyTo(*builder); + + ICEBERG_UNWRAP_OR_FAIL(auto metadata, builder->Build()); + EXPECT_EQ(metadata->table_uuid, "apply-uuid"); +} + +// Test AddSortOrder ApplyTo +TEST(TableUpdateTest, AddSortOrderApplyUpdate) { + auto base = CreateBaseMetadata(); + auto builder = TableMetadataBuilder::BuildFrom(base.get()); + + // Create a sort order + auto schema = CreateTestSchema(); + SortField sort_field(1, Transform::Identity(), SortDirection::kAscending, + NullOrder::kFirst); + auto sort_order = 
std::shared_ptr( + SortOrder::Make(*schema, 1, std::vector{sort_field}).value().release()); + + // Apply AddSortOrder update + table::AddSortOrder add_sort_order(sort_order); + add_sort_order.ApplyTo(*builder); + + ICEBERG_UNWRAP_OR_FAIL(auto metadata, builder->Build()); + + // Verify the sort order was added + ASSERT_EQ(metadata->sort_orders.size(), 2); // unsorted + new order + auto& added_order = metadata->sort_orders[1]; + EXPECT_EQ(added_order->order_id(), 1); + EXPECT_EQ(added_order->fields().size(), 1); + EXPECT_EQ(added_order->fields()[0].source_id(), 1); + EXPECT_EQ(added_order->fields()[0].direction(), SortDirection::kAscending); + EXPECT_EQ(added_order->fields()[0].null_order(), NullOrder::kFirst); +} + +// Test SetDefaultSortOrder ApplyTo +TEST(TableUpdateTest, SetDefaultSortOrderApplyUpdate) { + auto base = CreateBaseMetadata(); + + // add a sort order to the base metadata + auto schema = CreateTestSchema(); + SortField sort_field(1, Transform::Identity(), SortDirection::kDescending, + NullOrder::kLast); + auto sort_order = std::shared_ptr( + SortOrder::Make(*schema, 1, std::vector{sort_field}).value().release()); + base->sort_orders.push_back(sort_order); + + auto builder = TableMetadataBuilder::BuildFrom(base.get()); + + // Apply SetDefaultSortOrder update to set the new sort order as default + table::SetDefaultSortOrder set_default_sort_order(1); + set_default_sort_order.ApplyTo(*builder); + + ICEBERG_UNWRAP_OR_FAIL(auto metadata, builder->Build()); - // Existing table with empty UUID - no requirements - base->table_uuid = ""; - auto empty_uuid_reqs = GenerateRequirements(update, base.get()); - EXPECT_TRUE(empty_uuid_reqs.empty()); + // Verify the default sort order was changed + EXPECT_EQ(metadata->default_sort_order_id, 1); } } // namespace iceberg diff --git a/src/iceberg/test/transform_test.cc b/src/iceberg/test/transform_test.cc index 821edac55..7f0514df4 100644 --- a/src/iceberg/test/transform_test.cc +++ b/src/iceberg/test/transform_test.cc 
@@ -36,7 +36,6 @@ #include "iceberg/type.h" #include "iceberg/util/checked_cast.h" #include "iceberg/util/formatter.h" // IWYU pragma: keep -#include "iceberg/util/macros.h" namespace iceberg { @@ -954,12 +953,12 @@ TEST_F(TransformProjectTest, IdentityProjectEquality) { // Test equality predicate auto unbound = Expressions::Equal("value", Literal::Int(100)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kEq); @@ -977,23 +976,23 @@ TEST_F(TransformProjectTest, IdentityProjectComparison) { // Test less than predicate auto unbound_lt = Expressions::LessThan("value", Literal::Int(50)); - ICEBERG_ASSIGN_OR_THROW(auto bound_lt, - unbound_lt->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_lt, + unbound_lt->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred_lt = std::dynamic_pointer_cast(bound_lt); ASSERT_NE(bound_pred_lt, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_lt, transform->Project("part", bound_pred_lt)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_lt, transform->Project("part", bound_pred_lt)); ASSERT_NE(projected_lt, nullptr); EXPECT_EQ(projected_lt->op(), Expression::Operation::kLt); // Test greater than or equal predicate auto unbound_gte = Expressions::GreaterThanOrEqual("value", Literal::Int(100)); - ICEBERG_ASSIGN_OR_THROW(auto bound_gte, - unbound_gte->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_gte, + unbound_gte->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred_gte = 
std::dynamic_pointer_cast(bound_gte); ASSERT_NE(bound_pred_gte, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_gte, transform->Project("part", bound_pred_gte)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_gte, transform->Project("part", bound_pred_gte)); ASSERT_NE(projected_gte, nullptr); EXPECT_EQ(projected_gte->op(), Expression::Operation::kGtEq); } @@ -1003,13 +1002,13 @@ TEST_F(TransformProjectTest, IdentityProjectUnary) { // Test IsNull predicate auto unbound_null = Expressions::IsNull("value"); - ICEBERG_ASSIGN_OR_THROW(auto bound_null, - unbound_null->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_null, + unbound_null->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred_null = std::dynamic_pointer_cast(bound_null); ASSERT_NE(bound_pred_null, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_null, - transform->Project("part", bound_pred_null)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_null, + transform->Project("part", bound_pred_null)); ASSERT_NE(projected_null, nullptr); EXPECT_EQ(projected_null->op(), Expression::Operation::kIsNull); } @@ -1020,12 +1019,12 @@ TEST_F(TransformProjectTest, IdentityProjectSet) { // Test IN predicate auto unbound_in = Expressions::In("value", {Literal::Int(1), Literal::Int(2), Literal::Int(3)}); - ICEBERG_ASSIGN_OR_THROW(auto bound_in, - unbound_in->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_in, + unbound_in->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred_in = std::dynamic_pointer_cast(bound_in); ASSERT_NE(bound_pred_in, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_in, transform->Project("part", bound_pred_in)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_in, transform->Project("part", bound_pred_in)); ASSERT_NE(projected_in, nullptr); EXPECT_EQ(projected_in->op(), Expression::Operation::kIn); auto unbound_projected = @@ -1046,12 +1045,12 @@ TEST_F(TransformProjectTest, BucketProjectEquality) { // Bucket can project equality 
predicates auto unbound = Expressions::Equal("value", Literal::Int(34)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kEq); @@ -1070,8 +1069,8 @@ TEST_F(TransformProjectTest, BucketProjectWithMatchingTransformedChild) { // Create a predicate like: bucket(value, 16) = 5 auto bucket_term = Expressions::Bucket("value", 16); auto unbound = Expressions::Equal(bucket_term, Literal::Int(5)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); @@ -1080,8 +1079,8 @@ TEST_F(TransformProjectTest, BucketProjectWithMatchingTransformedChild) { // When the transform matches, Project should use RemoveTransform and return the // predicate - ICEBERG_ASSIGN_OR_THROW(auto projected, - partition_transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, + partition_transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kEq); auto unbound_projected = @@ -1098,12 +1097,12 @@ TEST_F(TransformProjectTest, BucketProjectComparisonReturnsNull) { // Bucket cannot project comparison predicates (they return null) auto unbound_lt = Expressions::LessThan("value", Literal::Int(50)); - ICEBERG_ASSIGN_OR_THROW(auto bound_lt, - unbound_lt->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_lt, + 
unbound_lt->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred_lt = std::dynamic_pointer_cast(bound_lt); ASSERT_NE(bound_pred_lt, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_lt, transform->Project("part", bound_pred_lt)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_lt, transform->Project("part", bound_pred_lt)); EXPECT_EQ(projected_lt, nullptr); } @@ -1113,12 +1112,12 @@ TEST_F(TransformProjectTest, BucketProjectInSet) { // Bucket can project IN predicates auto unbound_in = Expressions::In("value", {Literal::Int(1), Literal::Int(2), Literal::Int(3)}); - ICEBERG_ASSIGN_OR_THROW(auto bound_in, - unbound_in->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_in, + unbound_in->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred_in = std::dynamic_pointer_cast(bound_in); ASSERT_NE(bound_pred_in, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_in, transform->Project("part", bound_pred_in)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_in, transform->Project("part", bound_pred_in)); ASSERT_NE(projected_in, nullptr); EXPECT_EQ(projected_in->op(), Expression::Operation::kIn); } @@ -1129,13 +1128,13 @@ TEST_F(TransformProjectTest, BucketProjectNotInReturnsNull) { // Bucket cannot project NOT IN predicates auto unbound_not_in = Expressions::NotIn("value", {Literal::Int(1), Literal::Int(2), Literal::Int(3)}); - ICEBERG_ASSIGN_OR_THROW(auto bound_not_in, - unbound_not_in->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_not_in, + unbound_not_in->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred_not_in = std::dynamic_pointer_cast(bound_not_in); ASSERT_NE(bound_pred_not_in, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_not_in, - transform->Project("part", bound_pred_not_in)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_not_in, + transform->Project("part", bound_pred_not_in)); EXPECT_EQ(projected_not_in, nullptr); } @@ -1144,12 +1143,12 @@ TEST_F(TransformProjectTest, 
TruncateProjectIntEquality) { // Truncate can project equality predicates auto unbound = Expressions::Equal("value", Literal::Int(123)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kEq); @@ -1167,12 +1166,12 @@ TEST_F(TransformProjectTest, TruncateProjectIntLessThan) { // Truncate projects LT as LTE auto unbound = Expressions::LessThan("value", Literal::Int(25)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kLtEq); } @@ -1182,12 +1181,12 @@ TEST_F(TransformProjectTest, TruncateProjectIntGreaterThan) { // Truncate projects GT as GTE auto unbound = Expressions::GreaterThan("value", Literal::Int(25)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); 
ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kGtEq); @@ -1204,12 +1203,12 @@ TEST_F(TransformProjectTest, TruncateProjectStringEquality) { auto transform = Transform::Truncate(5); auto unbound = Expressions::Equal("value", Literal::String("Hello, World!")); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*string_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kEq); @@ -1228,13 +1227,13 @@ TEST_F(TransformProjectTest, TruncateProjectStringStartsWith) { // StartsWith with shorter string than width auto unbound_short = Expressions::StartsWith("value", "Hi"); - ICEBERG_ASSIGN_OR_THROW(auto bound_short, - unbound_short->Bind(*string_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_short, + unbound_short->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_short = std::dynamic_pointer_cast(bound_short); ASSERT_NE(bound_pred_short, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_short, - transform->Project("part", bound_pred_short)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_short, + transform->Project("part", bound_pred_short)); ASSERT_NE(projected_short, nullptr); EXPECT_EQ(projected_short->op(), Expression::Operation::kStartsWith); @@ -1249,13 +1248,13 @@ TEST_F(TransformProjectTest, TruncateProjectStringStartsWith) { // StartsWith with string equal to width auto unbound_equal = Expressions::StartsWith("value", "Hello"); - ICEBERG_ASSIGN_OR_THROW(auto bound_equal, - unbound_equal->Bind(*string_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_equal, + 
unbound_equal->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_equal = std::dynamic_pointer_cast(bound_equal); ASSERT_NE(bound_pred_equal, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_equal, - transform->Project("part", bound_pred_equal)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_equal, + transform->Project("part", bound_pred_equal)); ASSERT_NE(projected_equal, nullptr); EXPECT_EQ(projected_equal->op(), Expression::Operation::kEq); @@ -1275,15 +1274,15 @@ TEST_F(TransformProjectTest, TruncateProjectStringStartsWithCodePointCountLessTh // Code point count < width (multi-byte UTF-8 characters) // "😜🧐" has 2 code points, width is 5 auto unbound_emoji_short = Expressions::StartsWith("value", "😜🧐"); - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto bound_emoji_short, unbound_emoji_short->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_emoji_short = std::dynamic_pointer_cast(bound_emoji_short); ASSERT_NE(bound_pred_emoji_short, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_emoji_short, - transform->Project("part", bound_pred_emoji_short)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_emoji_short, + transform->Project("part", bound_pred_emoji_short)); ASSERT_NE(projected_emoji_short, nullptr); EXPECT_EQ(projected_emoji_short->op(), Expression::Operation::kStartsWith); @@ -1304,15 +1303,15 @@ TEST_F(TransformProjectTest, TruncateProjectStringStartsWithCodePointCountEqualT // Code point count == width (exactly 5 code points) // "😜🧐🤔🤪🥳" has exactly 5 code points auto unbound_emoji_equal = Expressions::StartsWith("value", "😜🧐🤔🤪🥳"); - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto bound_emoji_equal, unbound_emoji_equal->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_emoji_equal = std::dynamic_pointer_cast(bound_emoji_equal); ASSERT_NE(bound_pred_emoji_equal, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_emoji_equal, - transform->Project("part", bound_pred_emoji_equal)); + ICEBERG_UNWRAP_OR_FAIL(auto 
projected_emoji_equal, + transform->Project("part", bound_pred_emoji_equal)); ASSERT_NE(projected_emoji_equal, nullptr); EXPECT_EQ(projected_emoji_equal->op(), Expression::Operation::kEq); @@ -1335,15 +1334,15 @@ TEST_F(TransformProjectTest, // "😜🧐🤔🤪🥳😵‍💫😂" has 7 code points, should truncate to 5 auto unbound_emoji_long = Expressions::StartsWith("value", "😜🧐🤔🤪🥳😵‍💫😂"); - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto bound_emoji_long, unbound_emoji_long->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_emoji_long = std::dynamic_pointer_cast(bound_emoji_long); ASSERT_NE(bound_pred_emoji_long, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_emoji_long, - transform->Project("part", bound_pred_emoji_long)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_emoji_long, + transform->Project("part", bound_pred_emoji_long)); ASSERT_NE(projected_emoji_long, nullptr); EXPECT_EQ(projected_emoji_long->op(), Expression::Operation::kStartsWith); @@ -1364,15 +1363,15 @@ TEST_F(TransformProjectTest, TruncateProjectStringStartsWithMixedAsciiAndMultiBy // Mixed ASCII and multi-byte UTF-8 characters // "a😜b🧐c" has 5 code points (3 ASCII + 2 emojis) auto unbound_mixed_equal = Expressions::StartsWith("value", "a😜b🧐c"); - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto bound_mixed_equal, unbound_mixed_equal->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_mixed_equal = std::dynamic_pointer_cast(bound_mixed_equal); ASSERT_NE(bound_pred_mixed_equal, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_mixed_equal, - transform->Project("part", bound_pred_mixed_equal)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_mixed_equal, + transform->Project("part", bound_pred_mixed_equal)); ASSERT_NE(projected_mixed_equal, nullptr); EXPECT_EQ(projected_mixed_equal->op(), Expression::Operation::kEq); @@ -1393,15 +1392,15 @@ TEST_F(TransformProjectTest, TruncateProjectStringStartsWithChineseCharactersSho // Chinese characters (3-byte UTF-8) // "你好世界" has 4 code points, 
width is 5 auto unbound_chinese_short = Expressions::StartsWith("value", "你好世界"); - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto bound_chinese_short, unbound_chinese_short->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_chinese_short = std::dynamic_pointer_cast(bound_chinese_short); ASSERT_NE(bound_pred_chinese_short, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_chinese_short, - transform->Project("part", bound_pred_chinese_short)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_chinese_short, + transform->Project("part", bound_pred_chinese_short)); ASSERT_NE(projected_chinese_short, nullptr); EXPECT_EQ(projected_chinese_short->op(), Expression::Operation::kStartsWith); @@ -1422,15 +1421,15 @@ TEST_F(TransformProjectTest, TruncateProjectStringStartsWithChineseCharactersEqu // Chinese characters exactly matching width // "你好世界好" has exactly 5 code points auto unbound_chinese_equal = Expressions::StartsWith("value", "你好世界好"); - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto bound_chinese_equal, unbound_chinese_equal->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_chinese_equal = std::dynamic_pointer_cast(bound_chinese_equal); ASSERT_NE(bound_pred_chinese_equal, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_chinese_equal, - transform->Project("part", bound_pred_chinese_equal)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_chinese_equal, + transform->Project("part", bound_pred_chinese_equal)); ASSERT_NE(projected_chinese_equal, nullptr); EXPECT_EQ(projected_chinese_equal->op(), Expression::Operation::kEq); @@ -1452,15 +1451,15 @@ TEST_F(TransformProjectTest, // NotStartsWith with code point count == width // Should convert to NotEq auto unbound_not_starts_equal = Expressions::NotStartsWith("value", "😜🧐🤔🤪🥳"); - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto bound_not_starts_equal, unbound_not_starts_equal->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_not_starts_equal = 
std::dynamic_pointer_cast(bound_not_starts_equal); ASSERT_NE(bound_pred_not_starts_equal, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_not_starts_equal, - transform->Project("part", bound_pred_not_starts_equal)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_not_starts_equal, + transform->Project("part", bound_pred_not_starts_equal)); ASSERT_NE(projected_not_starts_equal, nullptr); EXPECT_EQ(projected_not_starts_equal->op(), Expression::Operation::kNotEq); @@ -1482,15 +1481,15 @@ TEST_F(TransformProjectTest, // NotStartsWith with code point count < width // Should remain NotStartsWith auto unbound_not_starts_short = Expressions::NotStartsWith("value", "😜🧐"); - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto bound_not_starts_short, unbound_not_starts_short->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_not_starts_short = std::dynamic_pointer_cast(bound_not_starts_short); ASSERT_NE(bound_pred_not_starts_short, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_not_starts_short, - transform->Project("part", bound_pred_not_starts_short)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_not_starts_short, + transform->Project("part", bound_pred_not_starts_short)); ASSERT_NE(projected_not_starts_short, nullptr); EXPECT_EQ(projected_not_starts_short->op(), Expression::Operation::kNotStartsWith); @@ -1514,15 +1513,15 @@ TEST_F(TransformProjectTest, // Should return nullptr (cannot project) auto unbound_not_starts_long = Expressions::NotStartsWith("value", "😜🧐🤔🤪🥳😵‍💫😂"); - ICEBERG_ASSIGN_OR_THROW( + ICEBERG_UNWRAP_OR_FAIL( auto bound_not_starts_long, unbound_not_starts_long->Bind(*string_schema_, /*case_sensitive=*/true)); auto bound_pred_not_starts_long = std::dynamic_pointer_cast(bound_not_starts_long); ASSERT_NE(bound_pred_not_starts_long, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_not_starts_long, - transform->Project("part", bound_pred_not_starts_long)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_not_starts_long, + transform->Project("part", 
bound_pred_not_starts_long)); EXPECT_EQ(projected_not_starts_long, nullptr); } @@ -1533,12 +1532,12 @@ TEST_F(TransformProjectTest, YearProjectEquality) { int32_t date_value = TemporalTestHelper::CreateDate({.year = 2021, .month = 6, .day = 1}); auto unbound = Expressions::Equal("value", Literal::Date(date_value)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*date_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*date_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kEq); } @@ -1551,23 +1550,23 @@ TEST_F(TransformProjectTest, YearProjectComparison) { // LT projects to LTE auto unbound_lt = Expressions::LessThan("value", Literal::Date(date_value)); - ICEBERG_ASSIGN_OR_THROW(auto bound_lt, - unbound_lt->Bind(*date_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_lt, + unbound_lt->Bind(*date_schema_, /*case_sensitive=*/true)); auto bound_pred_lt = std::dynamic_pointer_cast(bound_lt); ASSERT_NE(bound_pred_lt, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_lt, transform->Project("part", bound_pred_lt)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_lt, transform->Project("part", bound_pred_lt)); ASSERT_NE(projected_lt, nullptr); EXPECT_EQ(projected_lt->op(), Expression::Operation::kLtEq); // GT projects to GTE auto unbound_gt = Expressions::GreaterThan("value", Literal::Date(date_value)); - ICEBERG_ASSIGN_OR_THROW(auto bound_gt, - unbound_gt->Bind(*date_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_gt, + unbound_gt->Bind(*date_schema_, /*case_sensitive=*/true)); auto bound_pred_gt = std::dynamic_pointer_cast(bound_gt); ASSERT_NE(bound_pred_gt, nullptr); - 
ICEBERG_ASSIGN_OR_THROW(auto projected_gt, transform->Project("part", bound_pred_gt)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_gt, transform->Project("part", bound_pred_gt)); ASSERT_NE(projected_gt, nullptr); EXPECT_EQ(projected_gt->op(), Expression::Operation::kGtEq); } @@ -1578,12 +1577,12 @@ TEST_F(TransformProjectTest, MonthProjectEquality) { int64_t ts_value = TemporalTestHelper::CreateTimestamp({.year = 2021, .month = 6, .day = 1}); auto unbound = Expressions::Equal("value", Literal::Timestamp(ts_value)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kEq); } @@ -1594,12 +1593,12 @@ TEST_F(TransformProjectTest, DayProjectEquality) { int32_t date_value = TemporalTestHelper::CreateDate({.year = 2021, .month = 6, .day = 15}); auto unbound = Expressions::Equal("value", Literal::Date(date_value)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*date_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*date_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kEq); } @@ -1610,12 +1609,12 @@ TEST_F(TransformProjectTest, HourProjectEquality) { int64_t ts_value = TemporalTestHelper::CreateTimestamp( {.year = 2021, .month = 6, .day = 1, .hour = 
14, .minute = 30}); auto unbound = Expressions::Equal("value", Literal::Timestamp(ts_value)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); EXPECT_EQ(projected->op(), Expression::Operation::kEq); } @@ -1624,13 +1623,13 @@ TEST_F(TransformProjectTest, VoidProjectReturnsNull) { auto transform = Transform::Void(); auto unbound = Expressions::Equal("value", Literal::Int(100)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); // Void transform always returns null (no projection possible) - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); EXPECT_EQ(projected, nullptr); } @@ -1643,12 +1642,12 @@ TEST_F(TransformProjectTest, TemporalProjectInSet) { auto unbound_in = Expressions::In( "value", {Literal::Date(date1), Literal::Date(date2), Literal::Date(date3)}); - ICEBERG_ASSIGN_OR_THROW(auto bound_in, - unbound_in->Bind(*date_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_in, + unbound_in->Bind(*date_schema_, /*case_sensitive=*/true)); auto bound_pred_in = std::dynamic_pointer_cast(bound_in); ASSERT_NE(bound_pred_in, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected_in, transform->Project("part", bound_pred_in)); + ICEBERG_UNWRAP_OR_FAIL(auto projected_in, transform->Project("part", 
bound_pred_in)); ASSERT_NE(projected_in, nullptr); EXPECT_EQ(projected_in->op(), Expression::Operation::kIn); } @@ -1663,12 +1662,12 @@ TEST_F(TransformProjectTest, DayTimestampProjectionFix) { // If we fix (for buggy writers), we project to day <= 0. auto unbound = Expressions::LessThan("value", Literal::Timestamp(0)); - ICEBERG_ASSIGN_OR_THROW(auto bound, - unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); auto bound_pred = std::dynamic_pointer_cast(bound); ASSERT_NE(bound_pred, nullptr); - ICEBERG_ASSIGN_OR_THROW(auto projected, transform->Project("part", bound_pred)); + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->Project("part", bound_pred)); ASSERT_NE(projected, nullptr); auto unbound_projected = @@ -1680,4 +1679,560 @@ TEST_F(TransformProjectTest, DayTimestampProjectionFix) { EXPECT_EQ(val, 0) << "Expected projected value to be 0 (fix applied), but got " << val; } +// Test fixture for Transform::ProjectStrict tests +class TransformProjectStrictTest : public ::testing::Test { + protected: + void SetUp() override { + // Create test schemas for different source types + int_schema_ = std::make_shared( + std::vector{SchemaField::MakeRequired(1, "value", int32())}, + /*schema_id=*/0); + long_schema_ = std::make_shared( + std::vector{SchemaField::MakeRequired(1, "value", int64())}, + /*schema_id=*/0); + string_schema_ = std::make_shared( + std::vector{SchemaField::MakeRequired(1, "value", string())}, + /*schema_id=*/0); + date_schema_ = std::make_shared( + std::vector{SchemaField::MakeRequired(1, "value", date())}, + /*schema_id=*/0); + timestamp_schema_ = std::make_shared( + std::vector{SchemaField::MakeRequired(1, "value", timestamp())}, + /*schema_id=*/0); + decimal_schema_ = std::make_shared( + std::vector{SchemaField::MakeRequired(1, "value", decimal(9, 2))}, + /*schema_id=*/0); + } + + std::shared_ptr int_schema_; + std::shared_ptr long_schema_; 
+ std::shared_ptr string_schema_; + std::shared_ptr date_schema_; + std::shared_ptr timestamp_schema_; + std::shared_ptr decimal_schema_; +}; + +TEST_F(TransformProjectStrictTest, IdentityStrictProjection) { + auto transform = Transform::Identity(); + + // Identity strict projection should behave the same as inclusive projection + auto unbound = Expressions::Equal("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kEq); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kEq); + EXPECT_EQ(unbound_projected->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 100); +} + +TEST_F(TransformProjectStrictTest, BucketStrictEqualityReturnsFalse) { + auto transform = Transform::Bucket(10); + + // Bucket strict projection: equality should return FALSE (cannot guarantee equality) + auto unbound = Expressions::Equal("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + EXPECT_EQ(projected, nullptr); +} + +TEST_F(TransformProjectStrictTest, BucketStrictNotEqual) { + auto transform = Transform::Bucket(10); + + // Bucket strict projection: notEqual can be projected + auto unbound = Expressions::NotEqual("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = 
std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kNotEq); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kNotEq); + EXPECT_EQ(unbound_projected->literals().size(), 1); + // bucket(100, 10) = 6 + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 6); +} + +TEST_F(TransformProjectStrictTest, BucketStrictComparisonReturnsNull) { + auto transform = Transform::Bucket(10); + + // Bucket strict projection: comparison predicates return null + auto unbound_lt = Expressions::LessThan("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound_lt, + unbound_lt->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred_lt = std::dynamic_pointer_cast(bound_lt); + ASSERT_NE(bound_pred_lt, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected_lt, + transform->ProjectStrict("part", bound_pred_lt)); + EXPECT_EQ(projected_lt, nullptr); +} + +TEST_F(TransformProjectStrictTest, BucketStrictNotIn) { + auto transform = Transform::Bucket(10); + + // Bucket strict projection: NOT_IN can be projected + auto unbound_not_in = Expressions::NotIn( + "value", {Literal::Int(99), Literal::Int(100), Literal::Int(101)}); + ICEBERG_UNWRAP_OR_FAIL(auto bound_not_in, + unbound_not_in->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred_not_in = std::dynamic_pointer_cast(bound_not_in); + ASSERT_NE(bound_pred_not_in, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected_not_in, + transform->ProjectStrict("part", bound_pred_not_in)); + ASSERT_NE(projected_not_in, nullptr); + EXPECT_EQ(projected_not_in->op(), Expression::Operation::kNotIn); +} + +TEST_F(TransformProjectStrictTest, BucketStrictInReturnsNull) { + auto transform = Transform::Bucket(10); + + // Bucket strict 
projection: IN returns null (cannot guarantee) + auto unbound_in = + Expressions::In("value", {Literal::Int(99), Literal::Int(100), Literal::Int(101)}); + ICEBERG_UNWRAP_OR_FAIL(auto bound_in, + unbound_in->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred_in = std::dynamic_pointer_cast(bound_in); + ASSERT_NE(bound_pred_in, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected_in, + transform->ProjectStrict("part", bound_pred_in)); + EXPECT_EQ(projected_in, nullptr); +} + +TEST_F(TransformProjectStrictTest, BucketStrictString) { + auto transform = Transform::Bucket(10); + + // Bucket strict projection for string + auto unbound_not_eq = Expressions::NotEqual("value", Literal::String("abcdefg")); + ICEBERG_UNWRAP_OR_FAIL(auto bound_not_eq, + unbound_not_eq->Bind(*string_schema_, /*case_sensitive=*/true)); + auto bound_pred_not_eq = std::dynamic_pointer_cast(bound_not_eq); + ASSERT_NE(bound_pred_not_eq, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected_not_eq, + transform->ProjectStrict("part", bound_pred_not_eq)); + ASSERT_NE(projected_not_eq, nullptr); + EXPECT_EQ(projected_not_eq->op(), Expression::Operation::kNotEq); +} + +TEST_F(TransformProjectStrictTest, TruncateStrictIntEqualityReturnsNull) { + auto transform = Transform::Truncate(10); + + // Truncate strict projection: equality returns null (cannot guarantee) + auto unbound = Expressions::Equal("value", Literal::Int(123)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + EXPECT_EQ(projected, nullptr); +} + +TEST_F(TransformProjectStrictTest, TruncateStrictIntLessThan) { + auto transform = Transform::Truncate(10); + + // Truncate strict projection: LT projects to LT + auto unbound = Expressions::LessThan("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + 
unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kLt); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kLt); + EXPECT_EQ(unbound_projected->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 100); +} + +TEST_F(TransformProjectStrictTest, TruncateStrictIntLessThanOrEqual) { + auto transform = Transform::Truncate(10); + + // Truncate strict projection: LTE projects to LT + auto unbound = Expressions::LessThanOrEqual("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kLt); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kLt); + EXPECT_EQ(unbound_projected->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 100); +} + +TEST_F(TransformProjectStrictTest, TruncateStrictIntGreaterThan) { + auto transform = Transform::Truncate(10); + + // Truncate strict projection: GT projects to GT + auto unbound = Expressions::GreaterThan("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, 
transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kGt); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kGt); + EXPECT_EQ(unbound_projected->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 100); +} + +TEST_F(TransformProjectStrictTest, TruncateStrictIntGreaterThanOrEqualLowerBound) { + auto transform = Transform::Truncate(10); + + // Truncate strict projection: GTE projects to GT (lower bound, value = 100) + auto unbound = Expressions::GreaterThanOrEqual("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kGt); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kGt); + EXPECT_EQ(unbound_projected->literals().size(), 1); + // For GTE with value 100 and width 10, truncate(100) = 100, so GT should be 90 + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 90); +} + +TEST_F(TransformProjectStrictTest, TruncateStrictIntGreaterThanOrEqualUpperBound) { + auto transform = Transform::Truncate(10); + + // Truncate strict projection: GTE projects to GT (upper bound, value = 99) + auto unbound = Expressions::GreaterThanOrEqual("value", Literal::Int(99)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", 
bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kGt); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kGt); + EXPECT_EQ(unbound_projected->literals().size(), 1); + // For GTE with value 99 and width 10, truncate(99) = 90, so GT should be 90 + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 90); +} + +TEST_F(TransformProjectStrictTest, TruncateStrictIntNotEqual) { + auto transform = Transform::Truncate(10); + + // Truncate strict projection: notEqual can be projected + auto unbound = Expressions::NotEqual("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kNotEq); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kNotEq); + EXPECT_EQ(unbound_projected->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 100); +} + +TEST_F(TransformProjectStrictTest, TruncateStrictIntNotIn) { + auto transform = Transform::Truncate(10); + + // Truncate strict projection: NOT_IN can be projected + auto unbound_not_in = Expressions::NotIn( + "value", {Literal::Int(99), Literal::Int(100), Literal::Int(101)}); + ICEBERG_UNWRAP_OR_FAIL(auto bound_not_in, + unbound_not_in->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred_not_in = std::dynamic_pointer_cast(bound_not_in); + ASSERT_NE(bound_pred_not_in, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected_not_in, + transform->ProjectStrict("part", bound_pred_not_in)); + ASSERT_NE(projected_not_in, 
nullptr); + EXPECT_EQ(projected_not_in->op(), Expression::Operation::kNotIn); +} + +TEST_F(TransformProjectStrictTest, TruncateStrictString) { + auto transform = Transform::Truncate(5); + + // Truncate strict projection for string + auto unbound_lt = Expressions::LessThan("value", Literal::String("abcdefg")); + ICEBERG_UNWRAP_OR_FAIL(auto bound_lt, + unbound_lt->Bind(*string_schema_, /*case_sensitive=*/true)); + auto bound_pred_lt = std::dynamic_pointer_cast(bound_lt); + ASSERT_NE(bound_pred_lt, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected_lt, + transform->ProjectStrict("part", bound_pred_lt)); + ASSERT_NE(projected_lt, nullptr); + EXPECT_EQ(projected_lt->op(), Expression::Operation::kLt); + + auto unbound_projected_lt = + internal::checked_pointer_cast>( + std::move(projected_lt)); + EXPECT_EQ(unbound_projected_lt->op(), Expression::Operation::kLt); + EXPECT_EQ(unbound_projected_lt->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected_lt->literals().front().value()), + "abcde"); +} + +TEST_F(TransformProjectStrictTest, YearStrictEqualityReturnsNull) { + auto transform = Transform::Year(); + + // Year strict projection: equality returns null (cannot guarantee) + int32_t date_value = + TemporalTestHelper::CreateDate({.year = 2021, .month = 6, .day = 1}); + auto unbound = Expressions::Equal("value", Literal::Date(date_value)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*date_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + EXPECT_EQ(projected, nullptr); +} + +TEST_F(TransformProjectStrictTest, YearStrictLessThan) { + auto transform = Transform::Year(); + + // Year strict projection: LT projects to LT + int32_t date_value = + TemporalTestHelper::CreateDate({.year = 2021, .month = 1, .day = 1}); + auto unbound = Expressions::LessThan("value", Literal::Date(date_value)); + 
ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*date_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kLt); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kLt); + EXPECT_EQ(unbound_projected->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 2021); +} + +TEST_F(TransformProjectStrictTest, YearStrictGreaterThanOrEqual) { + auto transform = Transform::Year(); + + // Year strict projection: GTE projects to GT (lower bound) + int32_t date_value = + TemporalTestHelper::CreateDate({.year = 2021, .month = 1, .day = 1}); + auto unbound = Expressions::GreaterThanOrEqual("value", Literal::Date(date_value)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*date_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kGt); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kGt); + EXPECT_EQ(unbound_projected->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 2020); +} + +TEST_F(TransformProjectStrictTest, YearStrictNotEqual) { + auto transform = Transform::Year(); + + // Year strict projection: notEqual can be projected + int32_t date_value = + TemporalTestHelper::CreateDate({.year = 2021, .month = 1, .day = 1}); + auto unbound = Expressions::NotEqual("value", Literal::Date(date_value)); + 
ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*date_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kNotEq); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kNotEq); + EXPECT_EQ(unbound_projected->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 2021); +} + +TEST_F(TransformProjectStrictTest, MonthStrictLessThan) { + auto transform = Transform::Month(); + + // Month strict projection: LT projects to LT + int64_t ts_value = + TemporalTestHelper::CreateTimestamp({.year = 2017, .month = 12, .day = 1}); + auto unbound = Expressions::LessThan("value", Literal::Timestamp(ts_value)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kLt); +} + +TEST_F(TransformProjectStrictTest, DayStrictLessThan) { + auto transform = Transform::Day(); + + // Day strict projection: LT projects to LT + int64_t ts_value = + TemporalTestHelper::CreateTimestamp({.year = 2017, .month = 12, .day = 1}); + auto unbound = Expressions::LessThan("value", Literal::Timestamp(ts_value)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + 
EXPECT_EQ(projected->op(), Expression::Operation::kLt); +} + +TEST_F(TransformProjectStrictTest, HourStrictLessThan) { + auto transform = Transform::Hour(); + + // Hour strict projection: LT projects to LT + int64_t ts_value = TemporalTestHelper::CreateTimestamp( + {.year = 2017, .month = 12, .day = 1, .hour = 10, .minute = 0}); + auto unbound = Expressions::LessThan("value", Literal::Timestamp(ts_value)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kLt); +} + +TEST_F(TransformProjectStrictTest, DayStrictEpoch) { + auto transform = Transform::Day(); + + // Day strict projection at epoch: LT projects to LT + auto unbound = Expressions::LessThan("value", Literal::Timestamp(0)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kLt); +} + +TEST_F(TransformProjectStrictTest, MonthStrictNotEqualNegative) { + auto transform = Transform::Month(); + + // Month strict projection: notEqual with negative dates may convert to NOT_IN + int64_t ts_value = + TemporalTestHelper::CreateTimestamp({.year = 1969, .month = 1, .day = 1}); + auto unbound = Expressions::NotEqual("value", Literal::Timestamp(ts_value)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*timestamp_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, 
transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + // For negative dates, NOT_EQ may convert to NOT_IN + EXPECT_TRUE(projected->op() == Expression::Operation::kNotEq || + projected->op() == Expression::Operation::kNotIn); +} + +TEST_F(TransformProjectStrictTest, YearStrictUpperBound) { + auto transform = Transform::Year(); + + // Year strict projection: upper bound (end of year) + int32_t date_value = + TemporalTestHelper::CreateDate({.year = 2017, .month = 12, .day = 31}); + auto unbound = Expressions::LessThanOrEqual("value", Literal::Date(date_value)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*date_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + ASSERT_NE(projected, nullptr); + EXPECT_EQ(projected->op(), Expression::Operation::kLt); + + auto unbound_projected = + internal::checked_pointer_cast>( + std::move(projected)); + EXPECT_EQ(unbound_projected->op(), Expression::Operation::kLt); + EXPECT_EQ(unbound_projected->literals().size(), 1); + EXPECT_EQ(std::get(unbound_projected->literals().front().value()), 2018); +} + +TEST_F(TransformProjectStrictTest, VoidStrictReturnsNull) { + auto transform = Transform::Void(); + + // Void transform always returns null for strict projection + auto unbound = Expressions::Equal("value", Literal::Int(100)); + ICEBERG_UNWRAP_OR_FAIL(auto bound, + unbound->Bind(*int_schema_, /*case_sensitive=*/true)); + auto bound_pred = std::dynamic_pointer_cast(bound); + ASSERT_NE(bound_pred, nullptr); + + ICEBERG_UNWRAP_OR_FAIL(auto projected, transform->ProjectStrict("part", bound_pred)); + EXPECT_EQ(projected, nullptr); +} + } // namespace iceberg diff --git a/src/iceberg/transaction.h b/src/iceberg/transaction.h index 0bcedd6d8..72ba5182c 100644 --- a/src/iceberg/transaction.h +++ b/src/iceberg/transaction.h @@ -21,7 +21,6 @@ #pragma 
once #include -#include #include "iceberg/iceberg_export.h" #include "iceberg/result.h" diff --git a/src/iceberg/transform.cc b/src/iceberg/transform.cc index f8d2f0655..614489710 100644 --- a/src/iceberg/transform.cc +++ b/src/iceberg/transform.cc @@ -306,6 +306,66 @@ Result> Transform::Project( std::unreachable(); } +Result> Transform::ProjectStrict( + std::string_view name, const std::shared_ptr& predicate) { + switch (transform_type_) { + case TransformType::kIdentity: + return ProjectionUtil::IdentityProject(name, predicate); + case TransformType::kBucket: { + // If the predicate has a transformed child that matches the given transform, return + // a predicate. + if (predicate->term()->kind() == Term::Kind::kTransform) { + const auto boundTransform = + internal::checked_pointer_cast(predicate->term()); + if (*this == *boundTransform->transform()) { + return ProjectionUtil::RemoveTransform(name, predicate); + } else { + return nullptr; + } + } + ICEBERG_ASSIGN_OR_RAISE(auto func, Bind(predicate->term()->type())); + return ProjectionUtil::BucketProjectStrict(name, predicate, func); + } + case TransformType::kTruncate: { + // If the predicate has a transformed child that matches the given transform, return + // a predicate. + if (predicate->term()->kind() == Term::Kind::kTransform) { + const auto boundTransform = + internal::checked_pointer_cast(predicate->term()); + if (*this == *boundTransform->transform()) { + return ProjectionUtil::RemoveTransform(name, predicate); + } else { + return nullptr; + } + } + ICEBERG_ASSIGN_OR_RAISE(auto func, Bind(predicate->term()->type())); + return ProjectionUtil::TruncateProjectStrict(name, predicate, func); + } + case TransformType::kYear: + case TransformType::kMonth: + case TransformType::kDay: + case TransformType::kHour: { + // If the predicate has a transformed child that matches the given transform, return + // a predicate. 
+ if (predicate->term()->kind() == Term::Kind::kTransform) { + const auto boundTransform = + internal::checked_pointer_cast(predicate->term()); + if (*this == *boundTransform->transform()) { + return ProjectionUtil::RemoveTransform(name, predicate); + } else { + return nullptr; + } + } + ICEBERG_ASSIGN_OR_RAISE(auto func, Bind(predicate->term()->type())); + return ProjectionUtil::TemporalProjectStrict(name, predicate, func); + } + case TransformType::kUnknown: + case TransformType::kVoid: + return nullptr; + } + std::unreachable(); +} + bool TransformFunction::Equals(const TransformFunction& other) const { return transform_type_ == other.transform_type_ && *source_type_ == *other.source_type_; } diff --git a/src/iceberg/transform.h b/src/iceberg/transform.h index 64b850725..53993b4e3 100644 --- a/src/iceberg/transform.h +++ b/src/iceberg/transform.h @@ -182,6 +182,18 @@ class ICEBERG_EXPORT Transform : public util::Formattable { Result> Project( std::string_view name, const std::shared_ptr& predicate); + /// \brief Transforms a BoundPredicate to a strict predicate on the partition values + /// produced by the transform. + /// + /// This strict transform guarantees that if Projected(transform(value)) is true, then + /// predicate->Test(value) is also true. + /// \param name The name of the partition column. + /// \param predicate The predicate to project. + /// \return A Result containing either a unique pointer to the projected predicate, + /// nullptr if the projection cannot be performed, or an Error if the projection fails. + Result> ProjectStrict( + std::string_view name, const std::shared_ptr& predicate); + /// \brief Returns a string representation of this transform (e.g., "bucket[16]"). 
std::string ToString() const override; diff --git a/src/iceberg/util/error_collector.h b/src/iceberg/util/error_collector.h index f94967f97..48ea717eb 100644 --- a/src/iceberg/util/error_collector.h +++ b/src/iceberg/util/error_collector.h @@ -30,6 +30,21 @@ namespace iceberg { +#define BUILDER_RETURN_IF_ERROR(result) \ + if (auto&& result_name = result; !result_name) [[unlikely]] { \ + errors_.emplace_back(std::move(result_name.error())); \ + return *this; \ + } + +#define BUILDER_ASSIGN_OR_RETURN_IMPL(result_name, lhs, rexpr) \ + auto&& result_name = (rexpr); \ + BUILDER_RETURN_IF_ERROR(result_name) \ + lhs = std::move(result_name.value()); + +#define BUILDER_ASSIGN_OR_RETURN(lhs, rexpr) \ + BUILDER_ASSIGN_OR_RETURN_IMPL(ICEBERG_ASSIGN_OR_RAISE_NAME(result_, __COUNTER__), lhs, \ + rexpr) + /// \brief Base class for collecting validation errors in builder patterns /// /// This class provides error accumulation functionality for builders that diff --git a/src/iceberg/util/projection_util_internal.h b/src/iceberg/util/projection_util_internal.h index 3ce2dbf8f..df4fe9789 100644 --- a/src/iceberg/util/projection_util_internal.h +++ b/src/iceberg/util/projection_util_internal.h @@ -24,10 +24,12 @@ #include #include #include +#include #include #include "iceberg/expression/literal.h" #include "iceberg/expression/predicate.h" +#include "iceberg/expression/term.h" #include "iceberg/result.h" #include "iceberg/transform.h" #include "iceberg/transform_function.h" @@ -40,248 +42,230 @@ namespace iceberg { class ProjectionUtil { private: + static Result AdjustLiteral(const Literal& literal, int adjustment) { + switch (literal.type()->type_id()) { + case TypeId::kInt: + return Literal::Int(std::get(literal.value()) + adjustment); + case TypeId::kLong: + return Literal::Long(std::get(literal.value()) + adjustment); + case TypeId::kDate: + return Literal::Date(std::get(literal.value()) + adjustment); + case TypeId::kTimestamp: + return 
Literal::Timestamp(std::get(literal.value()) + adjustment); + case TypeId::kTimestampTz: + return Literal::TimestampTz(std::get(literal.value()) + adjustment); + case TypeId::kDecimal: { + const auto& decimal_type = + internal::checked_cast(*literal.type()); + Decimal adjusted = std::get(literal.value()) + Decimal(adjustment); + return Literal::Decimal(adjusted.value(), decimal_type.precision(), + decimal_type.scale()); + } + default: + return NotSupported("{} is not a valid literal type for value adjustment", + literal.type()->ToString()); + } + } + + static Result PlusOne(const Literal& literal) { + return AdjustLiteral(literal, /*adjustment=*/+1); + } + + static Result MinusOne(const Literal& literal) { + return AdjustLiteral(literal, /*adjustment=*/-1); + } + + static Result> MakePredicate( + Expression::Operation op, std::string_view name, + const std::shared_ptr& func, const Literal& literal) { + ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); + ICEBERG_ASSIGN_OR_RAISE(auto lit, func->Transform(literal)); + return UnboundPredicateImpl::Make(op, std::move(ref), std::move(lit)); + } + static Result> TransformSet( - std::string_view name, const std::shared_ptr& predicate, + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { std::vector transformed; - transformed.reserve(predicate->literal_set().size()); - for (const auto& lit : predicate->literal_set()) { + transformed.reserve(pred->literal_set().size()); + for (const auto& lit : pred->literal_set()) { ICEBERG_ASSIGN_OR_RAISE(auto transformed_lit, func->Transform(lit)); transformed.push_back(std::move(transformed_lit)); } ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); - return UnboundPredicateImpl::Make(predicate->op(), std::move(ref), + return UnboundPredicateImpl::Make(pred->op(), std::move(ref), std::move(transformed)); } - // General transform for all literal predicates. 
This is used as a fallback for special - // cases that are not handled by the other transform functions. - static Result> GenericTransform( - std::unique_ptr ref, - const std::shared_ptr& predicate, + static Result> TruncateByteArray( + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { - ICEBERG_ASSIGN_OR_RAISE(auto transformed, func->Transform(predicate->literal())); - switch (predicate->op()) { + switch (pred->op()) { case Expression::Operation::kLt: - case Expression::Operation::kLtEq: { - return UnboundPredicateImpl::Make( - Expression::Operation::kLtEq, std::move(ref), std::move(transformed)); - } + case Expression::Operation::kLtEq: + return MakePredicate(Expression::Operation::kLtEq, name, func, pred->literal()); case Expression::Operation::kGt: - case Expression::Operation::kGtEq: { - return UnboundPredicateImpl::Make( - Expression::Operation::kGtEq, std::move(ref), std::move(transformed)); - } - case Expression::Operation::kEq: { - return UnboundPredicateImpl::Make( - Expression::Operation::kEq, std::move(ref), std::move(transformed)); - } + case Expression::Operation::kGtEq: + return MakePredicate(Expression::Operation::kGtEq, name, func, pred->literal()); + case Expression::Operation::kEq: + case Expression::Operation::kStartsWith: + return MakePredicate(pred->op(), name, func, pred->literal()); default: return nullptr; } } - static Result> TruncateByteArray( - std::string_view name, const std::shared_ptr& predicate, + static Result> TruncateByteArrayStrict( + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { - ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); - switch (predicate->op()) { - case Expression::Operation::kStartsWith: { - ICEBERG_ASSIGN_OR_RAISE(auto transformed, func->Transform(predicate->literal())); - return UnboundPredicateImpl::Make( - Expression::Operation::kStartsWith, std::move(ref), std::move(transformed)); - } + switch (pred->op()) { + case 
Expression::Operation::kLt: + case Expression::Operation::kLtEq: + return MakePredicate(Expression::Operation::kLt, name, func, pred->literal()); + case Expression::Operation::kGt: + case Expression::Operation::kGtEq: + return MakePredicate(Expression::Operation::kGt, name, func, pred->literal()); + case Expression::Operation::kNotEq: + return MakePredicate(Expression::Operation::kNotEq, name, func, pred->literal()); default: - return GenericTransform(std::move(ref), predicate, func); + return nullptr; } } - template - requires std::is_same_v || std::is_same_v - static Result> TruncateInteger( - std::string_view name, const std::shared_ptr& predicate, + // Apply to int32, int64, decimal, and temporal types + static Result> TransformNumeric( + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { - const Literal& literal = predicate->literal(); - ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); + switch (func->source_type()->type_id()) { + case TypeId::kInt: + case TypeId::kLong: + case TypeId::kDecimal: + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for numeric transform", + func->source_type()->ToString()); + } - switch (predicate->op()) { + switch (pred->op()) { case Expression::Operation::kLt: { // adjust closed and then transform ltEq - if constexpr (std::is_same_v) { - ICEBERG_ASSIGN_OR_RAISE( - auto transformed, - func->Transform(Literal::Int(std::get(literal.value()) - 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kLtEq, std::move(ref), std::move(transformed)); - } else { - ICEBERG_ASSIGN_OR_RAISE( - auto transformed, - func->Transform(Literal::Long(std::get(literal.value()) - 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kLtEq, std::move(ref), std::move(transformed)); - } + ICEBERG_ASSIGN_OR_RAISE(auto adjusted, MinusOne(pred->literal())); + return 
MakePredicate(Expression::Operation::kLtEq, name, func, adjusted); } case Expression::Operation::kGt: { // adjust closed and then transform gtEq - if constexpr (std::is_same_v) { - ICEBERG_ASSIGN_OR_RAISE( - auto transformed, - func->Transform(Literal::Int(std::get(literal.value()) + 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kGtEq, std::move(ref), std::move(transformed)); - } else { - ICEBERG_ASSIGN_OR_RAISE( - auto transformed, - func->Transform(Literal::Long(std::get(literal.value()) + 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kGtEq, std::move(ref), std::move(transformed)); - } + ICEBERG_ASSIGN_OR_RAISE(auto adjusted, PlusOne(pred->literal())); + return MakePredicate(Expression::Operation::kGtEq, name, func, adjusted); } + case Expression::Operation::kLtEq: + case Expression::Operation::kGtEq: + case Expression::Operation::kEq: + return MakePredicate(pred->op(), name, func, pred->literal()); default: - return GenericTransform(std::move(ref), predicate, func); + return nullptr; } } - static Result> TransformTemporal( - std::string_view name, const std::shared_ptr& predicate, + static Result> TransformNumericStrict( + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { - const Literal& literal = predicate->literal(); - ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); - switch (func->source_type()->type_id()) { - case TypeId::kDate: { - switch (predicate->op()) { - case Expression::Operation::kLt: { - ICEBERG_ASSIGN_OR_RAISE( - auto transformed, - func->Transform(Literal::Date(std::get(literal.value()) - 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kLtEq, std::move(ref), std::move(transformed)); - } - case Expression::Operation::kGt: { - ICEBERG_ASSIGN_OR_RAISE( - auto transformed, - func->Transform(Literal::Date(std::get(literal.value()) + 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kGtEq, std::move(ref), 
std::move(transformed)); - } - default: - return GenericTransform(std::move(ref), predicate, func); - } - } - case TypeId::kTimestamp: { - switch (predicate->op()) { - case Expression::Operation::kLt: { - ICEBERG_ASSIGN_OR_RAISE(auto transformed, - func->Transform(Literal::Timestamp( - std::get(literal.value()) - 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kLtEq, std::move(ref), std::move(transformed)); - } - case Expression::Operation::kGt: { - ICEBERG_ASSIGN_OR_RAISE(auto transformed, - func->Transform(Literal::Timestamp( - std::get(literal.value()) + 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kGtEq, std::move(ref), std::move(transformed)); - } - default: - return GenericTransform(std::move(ref), predicate, func); - } + case TypeId::kInt: + case TypeId::kLong: + case TypeId::kDecimal: + case TypeId::kDate: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + break; + default: + return NotSupported("{} is not a valid input type for numeric transform", + func->source_type()->ToString()); + } + + switch (pred->op()) { + case Expression::Operation::kLtEq: { + ICEBERG_ASSIGN_OR_RAISE(auto adjusted, PlusOne(pred->literal())); + return MakePredicate(Expression::Operation::kLt, name, func, adjusted); } - case TypeId::kTimestampTz: { - switch (predicate->op()) { - case Expression::Operation::kLt: { - ICEBERG_ASSIGN_OR_RAISE(auto transformed, - func->Transform(Literal::TimestampTz( - std::get(literal.value()) - 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kLtEq, std::move(ref), std::move(transformed)); - } - case Expression::Operation::kGt: { - ICEBERG_ASSIGN_OR_RAISE(auto transformed, - func->Transform(Literal::TimestampTz( - std::get(literal.value()) + 1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kGtEq, std::move(ref), std::move(transformed)); - } - default: - return GenericTransform(std::move(ref), predicate, func); - } + case Expression::Operation::kGtEq: { + 
ICEBERG_ASSIGN_OR_RAISE(auto adjusted, MinusOne(pred->literal())); + return MakePredicate(Expression::Operation::kGt, name, func, adjusted); } + case Expression::Operation::kLt: + case Expression::Operation::kGt: + case Expression::Operation::kNotEq: + return MakePredicate(pred->op(), name, func, pred->literal()); default: - return NotSupported("{} is not a valid input type for temporal transform", - func->source_type()->ToString()); + return nullptr; } } - static Result> TruncateDecimal( - std::string_view name, const std::shared_ptr& predicate, + static Result> TruncateStringLiteral( + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { - const Literal& boundary = predicate->literal(); - ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); + const auto op = pred->op(); + if (op != Expression::Operation::kStartsWith && + op != Expression::Operation::kNotStartsWith) { + return TruncateByteArray(name, pred, func); + } - // For boundary adjustments, extract type info once - auto make_adjusted_literal = [&boundary](int adjustment) { - const auto& type = internal::checked_pointer_cast(boundary.type()); - Decimal adjusted = std::get(boundary.value()) + Decimal(adjustment); - return Literal::Decimal(adjusted.value(), type->precision(), type->scale()); - }; + const auto& literal = pred->literal(); + const auto length = + StringUtils::CodePointCount(std::get(literal.value())); + const auto width = static_cast( + internal::checked_pointer_cast(func)->width()); - switch (predicate->op()) { - case Expression::Operation::kLt: { - // adjust closed and then transform ltEq - ICEBERG_ASSIGN_OR_RAISE(auto transformed, - func->Transform(make_adjusted_literal(-1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kLtEq, std::move(ref), std::move(transformed)); - } - case Expression::Operation::kGt: { - // adjust closed and then transform gtEq - ICEBERG_ASSIGN_OR_RAISE(auto transformed, - 
func->Transform(make_adjusted_literal(1))); - return UnboundPredicateImpl::Make( - Expression::Operation::kGtEq, std::move(ref), std::move(transformed)); + if (length < width) { + return MakePredicate(op, name, func, literal); + } + + if (length == width) { + if (op == Expression::Operation::kStartsWith) { + return MakePredicate(Expression::Operation::kEq, name, func, literal); + } else { + return MakePredicate(Expression::Operation::kNotEq, name, func, literal); } - default: - return GenericTransform(std::move(ref), predicate, func); } + + if (op == Expression::Operation::kStartsWith) { + return TruncateByteArray(name, pred, func); + } + + return nullptr; } - static Result> TruncateStringLiteral( - std::string_view name, const std::shared_ptr& predicate, + static Result> TruncateStringLiteralStrict( + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { - const auto op = predicate->op(); + const auto op = pred->op(); if (op != Expression::Operation::kStartsWith && op != Expression::Operation::kNotStartsWith) { - return TruncateByteArray(name, predicate, func); + return TruncateByteArrayStrict(name, pred, func); } - const auto& truncate_transform = - internal::checked_pointer_cast(func); - const auto& str_value = std::get(predicate->literal().value()); - const auto width = truncate_transform->width(); - ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); + const auto& literal = pred->literal(); + const auto length = + StringUtils::CodePointCount(std::get(literal.value())); + const auto width = static_cast( + internal::checked_pointer_cast(func)->width()); - if (StringUtils::CodePointCount(str_value) < width) { - return UnboundPredicateImpl::Make(op, std::move(ref), - predicate->literal()); + if (length < width) { + return MakePredicate(op, name, func, literal); } - if (StringUtils::CodePointCount(str_value) == width) { + if (length == width) { if (op == Expression::Operation::kStartsWith) { - return 
UnboundPredicateImpl::Make( - Expression::Operation::kEq, std::move(ref), predicate->literal()); + return MakePredicate(Expression::Operation::kEq, name, func, literal); } else { - return UnboundPredicateImpl::Make( - Expression::Operation::kNotEq, std::move(ref), predicate->literal()); + return MakePredicate(Expression::Operation::kNotEq, name, func, literal); } } - if (op == Expression::Operation::kStartsWith) { - ICEBERG_ASSIGN_OR_RAISE(auto transformed, func->Transform(predicate->literal())); - return UnboundPredicateImpl::Make( - Expression::Operation::kStartsWith, std::move(ref), std::move(transformed)); + if (op == Expression::Operation::kNotStartsWith) { + return MakePredicate(Expression::Operation::kNotStartsWith, name, func, literal); } return nullptr; @@ -304,14 +288,13 @@ class ProjectionUtil { const auto& literal = projected->literals().front(); ICEBERG_DCHECK(std::holds_alternative(literal.value()), "Expected int32_t"); - auto value = std::get(literal.value()); - if (value < 0) { + if (auto value = std::get(literal.value()); value < 0) { return UnboundPredicateImpl::Make(Expression::Operation::kLt, std::move(projected->term()), Literal::Int(value + 1)); } - return std::move(projected); + return projected; } case Expression::Operation::kLtEq: { @@ -319,34 +302,33 @@ class ProjectionUtil { const auto& literal = projected->literals().front(); ICEBERG_DCHECK(std::holds_alternative(literal.value()), "Expected int32_t"); - auto value = std::get(literal.value()); - if (value < 0) { + + if (auto value = std::get(literal.value()); value < 0) { return UnboundPredicateImpl::Make(Expression::Operation::kLtEq, std::move(projected->term()), Literal::Int(value + 1)); } - return std::move(projected); + return projected; } case Expression::Operation::kGt: case Expression::Operation::kGtEq: // incorrect projected values are already greater than the bound for GT, GT_EQ - return std::move(projected); + return projected; case Expression::Operation::kEq: { 
ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal"); const auto& literal = projected->literals().front(); ICEBERG_DCHECK(std::holds_alternative(literal.value()), "Expected int32_t"); - auto value = std::get(literal.value()); - if (value < 0) { + if (auto value = std::get(literal.value()); value < 0) { // match either the incorrect value (projectedValue + 1) or the correct value // (projectedValue) return UnboundPredicateImpl::Make( Expression::Operation::kIn, std::move(projected->term()), {literal, Literal::Int(value + 1)}); } - return std::move(projected); + return projected; } case Expression::Operation::kIn: { @@ -377,7 +359,7 @@ class ProjectionUtil { std::move(projected->term()), std::move(values)); } - return std::move(projected); + return projected; } case Expression::Operation::kNotIn: @@ -386,30 +368,128 @@ class ProjectionUtil { return nullptr; default: - return std::move(projected); + return projected; + } + } + + // Fixes a strict projection to account for incorrectly transformed values. + // align with Java implementation: + // https://github.com/apache/iceberg/blob/1.10.x/api/src/main/java/org/apache/iceberg/transforms/ProjectionUtil.java#L347 + static Result> FixStrictTimeProjection( + std::unique_ptr> projected) { + if (projected == nullptr) { + return nullptr; + } + + switch (projected->op()) { + case Expression::Operation::kLt: + case Expression::Operation::kLtEq: + // the correct bound is a correct strict projection for the incorrectly + // transformed values. + return projected; + + case Expression::Operation::kGt: { + // GT and GT_EQ need to be adjusted because values that do not match the predicate + // may have been transformed into partition values that match the projected + // predicate. 
+ ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal"); + const auto& literal = projected->literals().front(); + ICEBERG_DCHECK(std::holds_alternative(literal.value()), + "Expected int32_t"); + if (auto value = std::get(literal.value()); value <= 0) { + return UnboundPredicateImpl::Make(Expression::Operation::kGt, + std::move(projected->term()), + Literal::Int(value + 1)); + } + return projected; + } + + case Expression::Operation::kGtEq: { + ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal"); + const auto& literal = projected->literals().front(); + ICEBERG_DCHECK(std::holds_alternative(literal.value()), + "Expected int32_t"); + if (auto value = std::get(literal.value()); value <= 0) { + return UnboundPredicateImpl::Make(Expression::Operation::kGtEq, + std::move(projected->term()), + Literal::Int(value + 1)); + } + return projected; + } + + case Expression::Operation::kEq: + case Expression::Operation::kIn: + // there is no strict projection for EQ and IN + return nullptr; + + case Expression::Operation::kNotEq: { + ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal"); + const auto& literal = projected->literals().front(); + ICEBERG_DCHECK(std::holds_alternative(literal.value()), + "Expected int32_t"); + if (auto value = std::get(literal.value()); value < 0) { + return UnboundPredicateImpl::Make( + Expression::Operation::kNotIn, std::move(projected->term()), + {literal, Literal::Int(value + 1)}); + } + return projected; + } + + case Expression::Operation::kNotIn: { + ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal"); + const auto& literals = projected->literals(); + ICEBERG_DCHECK( + std::ranges::all_of(literals, + [](const auto& lit) { + return std::holds_alternative(lit.value()); + }), + "Expected int32_t"); + std::unordered_set value_set; + bool has_negative_value = false; + for (const auto& lit : literals) { + auto value = std::get(lit.value()); + 
value_set.insert(value); + if (value < 0) { + value_set.insert(value + 1); + has_negative_value = true; + } + } + if (has_negative_value) { + auto values = + std::views::transform(value_set, + [](int32_t value) { return Literal::Int(value); }) | + std::ranges::to(); + return UnboundPredicateImpl::Make(Expression::Operation::kNotIn, + std::move(projected->term()), + std::move(values)); + } + return projected; + } + + default: + return nullptr; } } public: static Result> IdentityProject( - std::string_view name, const std::shared_ptr& predicate) { + std::string_view name, const std::shared_ptr& pred) { ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); - switch (predicate->kind()) { + switch (pred->kind()) { case BoundPredicate::Kind::kUnary: { - return UnboundPredicateImpl::Make(predicate->op(), - std::move(ref)); + return UnboundPredicateImpl::Make(pred->op(), std::move(ref)); } case BoundPredicate::Kind::kLiteral: { const auto& literalPredicate = - internal::checked_pointer_cast(predicate); - return UnboundPredicateImpl::Make(predicate->op(), std::move(ref), + internal::checked_pointer_cast(pred); + return UnboundPredicateImpl::Make(pred->op(), std::move(ref), literalPredicate->literal()); } case BoundPredicate::Kind::kSet: { const auto& setPredicate = - internal::checked_pointer_cast(predicate); + internal::checked_pointer_cast(pred); return UnboundPredicateImpl::Make( - predicate->op(), std::move(ref), + pred->op(), std::move(ref), std::vector(setPredicate->literal_set().begin(), setPredicate->literal_set().end())); } @@ -418,30 +498,29 @@ class ProjectionUtil { } static Result> BucketProject( - std::string_view name, const std::shared_ptr& predicate, + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); - switch (predicate->kind()) { + switch (pred->kind()) { case BoundPredicate::Kind::kUnary: { - return 
UnboundPredicateImpl::Make(predicate->op(), - std::move(ref)); + return UnboundPredicateImpl::Make(pred->op(), std::move(ref)); } case BoundPredicate::Kind::kLiteral: { - if (predicate->op() == Expression::Operation::kEq) { + if (pred->op() == Expression::Operation::kEq) { const auto& literalPredicate = - internal::checked_pointer_cast(predicate); + internal::checked_pointer_cast(pred); ICEBERG_ASSIGN_OR_RAISE(auto transformed, func->Transform(literalPredicate->literal())); - return UnboundPredicateImpl::Make( - predicate->op(), std::move(ref), std::move(transformed)); + return UnboundPredicateImpl::Make(pred->op(), std::move(ref), + std::move(transformed)); } break; } case BoundPredicate::Kind::kSet: { // notIn can't be projected - if (predicate->op() == Expression::Operation::kIn) { + if (pred->op() == Expression::Operation::kIn) { const auto& setPredicate = - internal::checked_pointer_cast(predicate); + internal::checked_pointer_cast(pred); return TransformSet(name, setPredicate, func); } break; @@ -455,19 +534,19 @@ class ProjectionUtil { } static Result> TruncateProject( - std::string_view name, const std::shared_ptr& predicate, + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); // Handle unary predicates uniformly for all types - if (predicate->kind() == BoundPredicate::Kind::kUnary) { - return UnboundPredicateImpl::Make(predicate->op(), std::move(ref)); + if (pred->kind() == BoundPredicate::Kind::kUnary) { + return UnboundPredicateImpl::Make(pred->op(), std::move(ref)); } // Handle set predicates (kIn) uniformly for all types - if (predicate->kind() == BoundPredicate::Kind::kSet) { - if (predicate->op() == Expression::Operation::kIn) { + if (pred->kind() == BoundPredicate::Kind::kSet) { + if (pred->op() == Expression::Operation::kIn) { const auto& setPredicate = - internal::checked_pointer_cast(predicate); + internal::checked_pointer_cast(pred); 
return TransformSet(name, setPredicate, func); } return nullptr; @@ -475,15 +554,13 @@ class ProjectionUtil { // Handle literal predicates based on source type const auto& literalPredicate = - internal::checked_pointer_cast(predicate); + internal::checked_pointer_cast(pred); switch (func->source_type()->type_id()) { case TypeId::kInt: - return TruncateInteger(name, literalPredicate, func); case TypeId::kLong: - return TruncateInteger(name, literalPredicate, func); case TypeId::kDecimal: - return TruncateDecimal(name, literalPredicate, func); + return TransformNumeric(name, literalPredicate, func); case TypeId::kString: return TruncateStringLiteral(name, literalPredicate, func); case TypeId::kBinary: @@ -495,16 +572,16 @@ class ProjectionUtil { } static Result> TemporalProject( - std::string_view name, const std::shared_ptr& predicate, + std::string_view name, const std::shared_ptr& pred, const std::shared_ptr& func) { ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); - if (predicate->kind() == BoundPredicate::Kind::kUnary) { - return UnboundPredicateImpl::Make(predicate->op(), std::move(ref)); - } else if (predicate->kind() == BoundPredicate::Kind::kLiteral) { + if (pred->kind() == BoundPredicate::Kind::kUnary) { + return UnboundPredicateImpl::Make(pred->op(), std::move(ref)); + } else if (pred->kind() == BoundPredicate::Kind::kLiteral) { const auto& literalPredicate = - internal::checked_pointer_cast(predicate); + internal::checked_pointer_cast(pred); ICEBERG_ASSIGN_OR_RAISE(auto projected, - TransformTemporal(name, literalPredicate, func)); + TransformNumeric(name, literalPredicate, func)); if (func->transform_type() != TransformType::kDay || func->source_type()->type_id() != TypeId::kDate) { return FixInclusiveTimeProjection( @@ -512,10 +589,9 @@ class ProjectionUtil { std::move(projected))); } return projected; - } else if (predicate->kind() == BoundPredicate::Kind::kSet && - predicate->op() == Expression::Operation::kIn) { - const 
auto& setPredicate = - internal::checked_pointer_cast(predicate); + } else if (pred->kind() == BoundPredicate::Kind::kSet && + pred->op() == Expression::Operation::kIn) { + const auto& setPredicate = internal::checked_pointer_cast(pred); ICEBERG_ASSIGN_OR_RAISE(auto projected, TransformSet(name, setPredicate, func)); if (func->transform_type() != TransformType::kDay || func->source_type()->type_id() != TypeId::kDate) { @@ -530,30 +606,135 @@ class ProjectionUtil { } static Result> RemoveTransform( - std::string_view name, const std::shared_ptr& predicate) { + std::string_view name, const std::shared_ptr& pred) { ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); - switch (predicate->kind()) { + switch (pred->kind()) { case BoundPredicate::Kind::kUnary: { - return UnboundPredicateImpl::Make(predicate->op(), - std::move(ref)); + return UnboundPredicateImpl::Make(pred->op(), std::move(ref)); } case BoundPredicate::Kind::kLiteral: { const auto& literalPredicate = - internal::checked_pointer_cast(predicate); - return UnboundPredicateImpl::Make(predicate->op(), std::move(ref), + internal::checked_pointer_cast(pred); + return UnboundPredicateImpl::Make(pred->op(), std::move(ref), literalPredicate->literal()); } case BoundPredicate::Kind::kSet: { const auto& setPredicate = - internal::checked_pointer_cast(predicate); + internal::checked_pointer_cast(pred); return UnboundPredicateImpl::Make( - predicate->op(), std::move(ref), + pred->op(), std::move(ref), std::vector(setPredicate->literal_set().begin(), setPredicate->literal_set().end())); } } std::unreachable(); } + + static Result> BucketProjectStrict( + std::string_view name, const std::shared_ptr& pred, + const std::shared_ptr& func) { + ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); + switch (pred->kind()) { + case BoundPredicate::Kind::kUnary: { + return UnboundPredicateImpl::Make(pred->op(), std::move(ref)); + } + case BoundPredicate::Kind::kLiteral: { + if 
(pred->op() == Expression::Operation::kNotEq) { + const auto& literalPredicate = + internal::checked_pointer_cast(pred); + ICEBERG_ASSIGN_OR_RAISE(auto transformed, + func->Transform(literalPredicate->literal())); + // TODO(anyone): need to translate not(eq(...)) into notEq in expressions + return UnboundPredicateImpl::Make(pred->op(), std::move(ref), + std::move(transformed)); + } + break; + } + case BoundPredicate::Kind::kSet: { + if (pred->op() == Expression::Operation::kNotIn) { + const auto& setPredicate = + internal::checked_pointer_cast(pred); + return TransformSet(name, setPredicate, func); + } + break; + } + } + + // no strict projection for comparison or equality + return nullptr; + } + + static Result> TruncateProjectStrict( + std::string_view name, const std::shared_ptr& pred, + const std::shared_ptr& func) { + ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); + // Handle unary predicates uniformly for all types + if (pred->kind() == BoundPredicate::Kind::kUnary) { + return UnboundPredicateImpl::Make(pred->op(), std::move(ref)); + } + + // Handle set predicates (kNotIn) uniformly for all types + if (pred->kind() == BoundPredicate::Kind::kSet) { + if (pred->op() == Expression::Operation::kNotIn) { + const auto& setPredicate = + internal::checked_pointer_cast(pred); + return TransformSet(name, setPredicate, func); + } + return nullptr; + } + + // Handle literal predicates based on source type + const auto& literalPredicate = + internal::checked_pointer_cast(pred); + + switch (func->source_type()->type_id()) { + case TypeId::kInt: + case TypeId::kLong: + case TypeId::kDecimal: + return TransformNumericStrict(name, literalPredicate, func); + case TypeId::kString: + return TruncateStringLiteralStrict(name, literalPredicate, func); + case TypeId::kBinary: + return TruncateByteArrayStrict(name, literalPredicate, func); + default: + return NotSupported("{} is not a valid input type for truncate transform", + 
func->source_type()->ToString()); + } + } + + static Result> TemporalProjectStrict( + std::string_view name, const std::shared_ptr& pred, + const std::shared_ptr& func) { + ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name))); + if (pred->kind() == BoundPredicate::Kind::kUnary) { + return UnboundPredicateImpl::Make(pred->op(), std::move(ref)); + } else if (pred->kind() == BoundPredicate::Kind::kLiteral) { + const auto& literalPredicate = + internal::checked_pointer_cast(pred); + ICEBERG_ASSIGN_OR_RAISE(auto projected, + TransformNumericStrict(name, literalPredicate, func)); + if (func->transform_type() != TransformType::kDay || + func->source_type()->type_id() != TypeId::kDate) { + return FixStrictTimeProjection( + internal::checked_pointer_cast>( + std::move(projected))); + } + return projected; + } else if (pred->kind() == BoundPredicate::Kind::kSet && + pred->op() == Expression::Operation::kNotIn) { + const auto& setPredicate = internal::checked_pointer_cast(pred); + ICEBERG_ASSIGN_OR_RAISE(auto projected, TransformSet(name, setPredicate, func)); + if (func->transform_type() != TransformType::kDay || + func->source_type()->type_id() != TypeId::kDate) { + return FixStrictTimeProjection( + internal::checked_pointer_cast>( + std::move(projected))); + } + return projected; + } + + return nullptr; + } }; } // namespace iceberg