diff --git a/src/iceberg/type.cc b/src/iceberg/type.cc index e66f96daf..c8f15ac6f 100644 --- a/src/iceberg/type.cc +++ b/src/iceberg/type.cc @@ -28,6 +28,16 @@ namespace iceberg { +namespace { +bool StringEqualsCaseInsensitive(std::string_view lhs, std::string_view rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + return std::equal(lhs.begin(), lhs.end(), rhs.begin(), + [](char a, char b) { return std::tolower(a) == std::tolower(b); }); +} +} // namespace + StructType::StructType(std::vector fields) : fields_(std::move(fields)) { size_t index = 0; for (const auto& field : fields_) { @@ -59,6 +69,7 @@ std::optional> StructType::GetFieldByI if (it == field_id_to_index_.end()) return std::nullopt; return fields_[it->second]; } + std::optional> StructType::GetFieldByIndex( int32_t index) const { if (index < 0 || index >= static_cast(fields_.size())) { @@ -66,6 +77,7 @@ std::optional> StructType::GetFieldByI } return fields_[index]; } + std::optional> StructType::GetFieldByName( std::string_view name) const { // N.B. duplicate names are not permitted (looking at the Java @@ -77,6 +89,17 @@ std::optional> StructType::GetFieldByN } return std::nullopt; } + +std::optional> +StructType::GetFieldByNameCaseInsensitive(std::string_view name) const { + for (const auto& field : fields_) { + if (StringEqualsCaseInsensitive(field.name(), name)) { + return std::cref(field); + } + } + return std::nullopt; +} + bool StructType::Equals(const Type& other) const { if (other.type_id() != TypeId::kStruct) { return false; @@ -126,6 +149,15 @@ std::optional> ListType::GetFieldByNam } return std::nullopt; } + +std::optional> +ListType::GetFieldByNameCaseInsensitive(std::string_view name) const { + if (StringEqualsCaseInsensitive(element_.name(), name)) { + return std::cref(element_); + } + return std::nullopt; +} + bool ListType::Equals(const Type& other) const { if (other.type_id() != TypeId::kList) { return false; @@ -186,6 +218,17 @@ std::optional> MapType::GetFieldByName } return std::nullopt; } + +std::optional> +MapType::GetFieldByNameCaseInsensitive(std::string_view name) const { + if (StringEqualsCaseInsensitive(kKeyName, name)) { + return key(); + } else if (StringEqualsCaseInsensitive(kValueName, name)) { + return value(); + } + return std::nullopt; +} + bool MapType::Equals(const Type& other) const { if (other.type_id() != TypeId::kMap) { return false; diff --git a/src/iceberg/type.h b/src/iceberg/type.h index 78c0141b1..634829f1e 100644 --- a/src/iceberg/type.h +++ b/src/iceberg/type.h @@ -92,6 +92,9 @@ class ICEBERG_EXPORT NestedType : public Type { /// \note This is currently O(n) complexity. [[nodiscard]] virtual std::optional> GetFieldByName(std::string_view name) const = 0; + + [[nodiscard]] virtual std::optional> + GetFieldByNameCaseInsensitive(std::string_view name) const = 0; }; /// \defgroup type-nested Nested Types @@ -115,6 +118,8 @@ class ICEBERG_EXPORT StructType : public NestedType { int32_t index) const override; std::optional> GetFieldByName( std::string_view name) const override; + std::optional> GetFieldByNameCaseInsensitive( + std::string_view name) const override; protected: bool Equals(const Type& other) const override; @@ -146,6 +151,8 @@ class ICEBERG_EXPORT ListType : public NestedType { int32_t index) const override; std::optional> GetFieldByName( std::string_view name) const override; + std::optional> GetFieldByNameCaseInsensitive( + std::string_view name) const override; protected: bool Equals(const Type& other) const override; @@ -178,6 +185,8 @@ class ICEBERG_EXPORT MapType : public NestedType { int32_t index) const override; std::optional> GetFieldByName( std::string_view name) const override; + std::optional> GetFieldByNameCaseInsensitive( + std::string_view name) const override; protected: bool Equals(const Type& other) const override; diff --git a/test/type_test.cc b/test/type_test.cc index fca886928..cc8905ab5 100644 --- a/test/type_test.cc +++ b/test/type_test.cc @@ -318,11 +318,17 @@ TEST(TypeTest, List) { ASSERT_THAT(list.GetFieldById(5), ::testing::Optional(field)); ASSERT_THAT(list.GetFieldByIndex(0), ::testing::Optional(field)); ASSERT_THAT(list.GetFieldByName("element"), ::testing::Optional(field)); + ASSERT_THAT(list.GetFieldByNameCaseInsensitive("element"), + ::testing::Optional(field)); + ASSERT_THAT(list.GetFieldByNameCaseInsensitive("ELEMENT"), + ::testing::Optional(field)); ASSERT_EQ(std::nullopt, list.GetFieldById(0)); ASSERT_EQ(std::nullopt, list.GetFieldByIndex(1)); ASSERT_EQ(std::nullopt, list.GetFieldByIndex(-1)); ASSERT_EQ(std::nullopt, list.GetFieldByName("foo")); + ASSERT_EQ(std::nullopt, list.GetFieldByNameCaseInsensitive("foo")); + ASSERT_EQ(std::nullopt, list.GetFieldByNameCaseInsensitive("FOO")); } ASSERT_THAT( []() { @@ -347,12 +353,15 @@ TEST(TypeTest, Map) { ASSERT_THAT(map.GetFieldByIndex(0), ::testing::Optional(key)); ASSERT_THAT(map.GetFieldByIndex(1), ::testing::Optional(value)); ASSERT_THAT(map.GetFieldByName("key"), ::testing::Optional(key)); + ASSERT_THAT(map.GetFieldByNameCaseInsensitive("kEY"), ::testing::Optional(key)); ASSERT_THAT(map.GetFieldByName("value"), ::testing::Optional(value)); + ASSERT_THAT(map.GetFieldByName("vALUE"), ::testing::Optional(value)); ASSERT_EQ(std::nullopt, map.GetFieldById(0)); ASSERT_EQ(std::nullopt, map.GetFieldByIndex(2)); ASSERT_EQ(std::nullopt, map.GetFieldByIndex(-1)); ASSERT_EQ(std::nullopt, map.GetFieldByName("element")); + ASSERT_EQ(std::nullopt, map.GetFieldByName("elemENt")); } ASSERT_THAT( []() { @@ -386,12 +395,17 @@ TEST(TypeTest, Struct) { ASSERT_THAT(struct_.GetFieldByIndex(0), ::testing::Optional(field1)); ASSERT_THAT(struct_.GetFieldByIndex(1), ::testing::Optional(field2)); ASSERT_THAT(struct_.GetFieldByName("foo"), ::testing::Optional(field1)); + ASSERT_THAT(struct_.GetFieldByName("FOO"), ::testing::Optional(field1)); ASSERT_THAT(struct_.GetFieldByName("bar"), ::testing::Optional(field2)); + ASSERT_THAT(struct_.GetFieldByNameCaseInsensitive("bar"), + ::testing::Optional(field2)); ASSERT_EQ(std::nullopt, struct_.GetFieldById(0)); ASSERT_EQ(std::nullopt, struct_.GetFieldByIndex(2)); ASSERT_EQ(std::nullopt, struct_.GetFieldByIndex(-1)); ASSERT_EQ(std::nullopt, struct_.GetFieldByName("element")); + ASSERT_EQ(std::nullopt, struct_.GetFieldByNameCaseInsensitive("element")); + ASSERT_EQ(std::nullopt, struct_.GetFieldByNameCaseInsensitive("ELEMENT")); } ASSERT_THAT( []() {