Skip to content

Commit 06adc3f

Browse files
author
nullccxsy
committed
feat(type): add insensitive way to find schemafield & test
1 parent 55b0436 commit 06adc3f

File tree

3 files changed

+91
-8
lines changed

3 files changed

+91
-8
lines changed

src/iceberg/type.cc

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,16 @@
1919

2020
#include "iceberg/type.h"
2121

22+
#include <algorithm>
23+
#include <cctype>
2224
#include <format>
25+
#include <functional>
2326
#include <iterator>
2427
#include <memory>
28+
#include <optional>
29+
#include <ranges>
30+
#include <string_view>
31+
#include <iceberg/schema_field.h>
2532

2633
#include "iceberg/exception.h"
2734
#include "iceberg/util/formatter.h" // IWYU pragma: keep
@@ -70,12 +77,19 @@ std::optional<std::reference_wrapper<const SchemaField>> StructType::GetFieldByN
7077
std::string_view name) const {
7178
// N.B. duplicate names are not permitted (looking at the Java
7279
// implementation) so there is nothing in particular we need to do here
73-
for (const auto& field : fields_) {
74-
if (field.name() == name) {
75-
return field;
76-
}
77-
}
78-
return std::nullopt;
80+
InitNameToIdMap();
81+
auto it = field_name_to_index_.find(std::string(name));
82+
if (it == field_name_to_index_.end()) return std::nullopt;
83+
return fields_[it->second];
84+
}
85+
std::optional<std::reference_wrapper<const SchemaField>> StructType::GetFieldByNameCaseInsensitive(
86+
std::string_view name) const {
87+
InitNameToIdMapCaseInsensitive();
88+
std::string lower_name(name);
89+
std::ranges::transform(lower_name, lower_name.begin(), ::tolower);
90+
auto it = caseinsensitive_field_name_to_index_.find(lower_name);
91+
if (it == caseinsensitive_field_name_to_index_.end()) return std::nullopt;
92+
return fields_[it->second];
7993
}
8094
bool StructType::Equals(const Type& other) const {
8195
if (other.type_id() != TypeId::kStruct) {
@@ -84,6 +98,26 @@ bool StructType::Equals(const Type& other) const {
8498
const auto& struct_ = static_cast<const StructType&>(other);
8599
return fields_ == struct_.fields_;
86100
}
101+
void StructType::InitNameToIdMap() const {
102+
if (!field_name_to_index_.empty()) {
103+
return;
104+
}
105+
106+
for (int i = 0; i < fields_.size(); i++) {
107+
field_name_to_index_[std::string(fields_[i].name())] = i;
108+
}
109+
}
110+
void StructType::InitNameToIdMapCaseInsensitive() const {
111+
if (!caseinsensitive_field_name_to_index_.empty()) {
112+
return;
113+
}
114+
115+
for (int i = 0; i < fields_.size(); i++) {
116+
std::string lowercase_name(fields_[i].name());
117+
std::ranges::transform(lowercase_name, lowercase_name.begin(), ::tolower);
118+
caseinsensitive_field_name_to_index_[lowercase_name] = i;
119+
}
120+
}
87121

88122
ListType::ListType(SchemaField element) : element_(std::move(element)) {
89123
if (element_.name() != kElementName) {
@@ -126,6 +160,15 @@ std::optional<std::reference_wrapper<const SchemaField>> ListType::GetFieldByNam
126160
}
127161
return std::nullopt;
128162
}
163+
std::optional<std::reference_wrapper<const SchemaField>> ListType::GetFieldByNameCaseInsensitive(
164+
std::string_view name) const {
165+
auto lower_name_view = name | std::views::transform(::tolower);
166+
auto lower_field_name = element_.name() | std::views::transform(::tolower);
167+
if (std::ranges::equal(lower_field_name, lower_name_view)) {
168+
return std::cref(element_);
169+
}
170+
return std::nullopt;
171+
}
129172
bool ListType::Equals(const Type& other) const {
130173
if (other.type_id() != TypeId::kList) {
131174
return false;
@@ -186,6 +229,18 @@ std::optional<std::reference_wrapper<const SchemaField>> MapType::GetFieldByName
186229
}
187230
return std::nullopt;
188231
}
232+
std::optional<std::reference_wrapper<const SchemaField>> MapType::GetFieldByNameCaseInsensitive(
233+
std::string_view name) const {
234+
auto lower_name_view = name | std::views::transform(::tolower);
235+
auto lower_key_view = kKeyName | std::views::transform(tolower);
236+
auto lower_value_view = kValueName | std::views::transform(tolower);
237+
if (std::ranges::equal(lower_key_view, lower_name_view)) {
238+
return key();
239+
} else if (std::ranges::equal(lower_value_view, lower_name_view)) {
240+
return value();
241+
}
242+
return std::nullopt;
243+
}
189244
bool MapType::Equals(const Type& other) const {
190245
if (other.type_id() != TypeId::kMap) {
191246
return false;

src/iceberg/type.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include <array>
2727
#include <cstdint>
28+
#include <functional>
2829
#include <memory>
2930
#include <optional>
3031
#include <span>
@@ -81,17 +82,22 @@ class ICEBERG_EXPORT NestedType : public Type {
8182
[[nodiscard]] virtual std::optional<std::reference_wrapper<const SchemaField>>
8283
GetFieldById(int32_t field_id) const = 0;
8384
/// \brief Get a field by index.
85+
8486
///
8587
/// \note This is O(1) complexity.
8688
[[nodiscard]] virtual std::optional<std::reference_wrapper<const SchemaField>>
8789
GetFieldByIndex(int32_t index) const = 0;
8890
/// \brief Get a field by name (case-sensitive). Behavior is undefined if
8991
/// the field name is not unique; prefer GetFieldById or GetFieldByIndex
9092
/// when possible.
93+
9194
///
92-
/// \note This is currently O(n) complexity.
95+
/// \note This is currently O(1) complexity.
9396
[[nodiscard]] virtual std::optional<std::reference_wrapper<const SchemaField>>
9497
GetFieldByName(std::string_view name) const = 0;
98+
99+
[[nodiscard]] virtual std::optional<std::reference_wrapper<const SchemaField>>
100+
GetFieldByNameCaseInsensitive(std::string_view name) const = 0;
95101
};
96102

97103
/// \defgroup type-nested Nested Types
@@ -115,12 +121,18 @@ class ICEBERG_EXPORT StructType : public NestedType {
115121
int32_t index) const override;
116122
std::optional<std::reference_wrapper<const SchemaField>> GetFieldByName(
117123
std::string_view name) const override;
124+
std::optional<std::reference_wrapper<const SchemaField>> GetFieldByNameCaseInsensitive(
125+
std::string_view name) const override;
126+
void InitNameToIdMap() const;
127+
void InitNameToIdMapCaseInsensitive() const;
118128

119129
protected:
120130
bool Equals(const Type& other) const override;
121131

122132
std::vector<SchemaField> fields_;
123133
std::unordered_map<int32_t, size_t> field_id_to_index_;
134+
mutable std::unordered_map<std::string, size_t> field_name_to_index_;
135+
mutable std::unordered_map<std::string, size_t> caseinsensitive_field_name_to_index_;
124136
};
125137

126138
/// \brief A data type representing a list of values.
@@ -146,6 +158,8 @@ class ICEBERG_EXPORT ListType : public NestedType {
146158
int32_t index) const override;
147159
std::optional<std::reference_wrapper<const SchemaField>> GetFieldByName(
148160
std::string_view name) const override;
161+
std::optional<std::reference_wrapper<const SchemaField>> GetFieldByNameCaseInsensitive(
162+
std::string_view name) const override;
149163

150164
protected:
151165
bool Equals(const Type& other) const override;
@@ -178,7 +192,9 @@ class ICEBERG_EXPORT MapType : public NestedType {
178192
int32_t index) const override;
179193
std::optional<std::reference_wrapper<const SchemaField>> GetFieldByName(
180194
std::string_view name) const override;
181-
195+
std::optional<std::reference_wrapper<const SchemaField>> GetFieldByNameCaseInsensitive(
196+
std::string_view name) const override;
197+
182198
protected:
183199
bool Equals(const Type& other) const override;
184200

test/type_test.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121

2222
#include <format>
2323
#include <memory>
24+
#include <optional>
2425
#include <string>
2526

2627
#include <gmock/gmock.h>
2728
#include <gtest/gtest.h>
2829

30+
#include "gmock/gmock.h"
2931
#include "iceberg/exception.h"
3032
#include "iceberg/util/formatter.h" // IWYU pragma: keep
3133

@@ -318,11 +320,13 @@ TEST(TypeTest, List) {
318320
ASSERT_THAT(list.GetFieldById(5), ::testing::Optional(field));
319321
ASSERT_THAT(list.GetFieldByIndex(0), ::testing::Optional(field));
320322
ASSERT_THAT(list.GetFieldByName("element"), ::testing::Optional(field));
323+
ASSERT_THAT(list.GetFieldByNameCaseInsensitive("ELEMENT"), ::testing::Optional(field));
321324

322325
ASSERT_EQ(std::nullopt, list.GetFieldById(0));
323326
ASSERT_EQ(std::nullopt, list.GetFieldByIndex(1));
324327
ASSERT_EQ(std::nullopt, list.GetFieldByIndex(-1));
325328
ASSERT_EQ(std::nullopt, list.GetFieldByName("foo"));
329+
ASSERT_EQ(std::nullopt, list.GetFieldByNameCaseInsensitive("FOO"));
326330
}
327331
ASSERT_THAT(
328332
[]() {
@@ -348,11 +352,16 @@ TEST(TypeTest, Map) {
348352
ASSERT_THAT(map.GetFieldByIndex(1), ::testing::Optional(value));
349353
ASSERT_THAT(map.GetFieldByName("key"), ::testing::Optional(key));
350354
ASSERT_THAT(map.GetFieldByName("value"), ::testing::Optional(value));
355+
ASSERT_THAT(map.GetFieldByNameCaseInsensitive("KEY"), ::testing::Optional(key));
356+
ASSERT_THAT(map.GetFieldByNameCaseInsensitive("VALUE"), ::testing::Optional(value));
357+
351358

352359
ASSERT_EQ(std::nullopt, map.GetFieldById(0));
353360
ASSERT_EQ(std::nullopt, map.GetFieldByIndex(2));
354361
ASSERT_EQ(std::nullopt, map.GetFieldByIndex(-1));
355362
ASSERT_EQ(std::nullopt, map.GetFieldByName("element"));
363+
ASSERT_EQ(std::nullopt, map.GetFieldByName(""));
364+
ASSERT_EQ(std::nullopt, map.GetFieldByNameCaseInsensitive(""));
356365
}
357366
ASSERT_THAT(
358367
[]() {
@@ -387,11 +396,14 @@ TEST(TypeTest, Struct) {
387396
ASSERT_THAT(struct_.GetFieldByIndex(1), ::testing::Optional(field2));
388397
ASSERT_THAT(struct_.GetFieldByName("foo"), ::testing::Optional(field1));
389398
ASSERT_THAT(struct_.GetFieldByName("bar"), ::testing::Optional(field2));
399+
ASSERT_THAT(struct_.GetFieldByNameCaseInsensitive("FOO"), ::testing::Optional(field1));
400+
ASSERT_THAT(struct_.GetFieldByNameCaseInsensitive("Bar"), ::testing::Optional(field2));
390401

391402
ASSERT_EQ(std::nullopt, struct_.GetFieldById(0));
392403
ASSERT_EQ(std::nullopt, struct_.GetFieldByIndex(2));
393404
ASSERT_EQ(std::nullopt, struct_.GetFieldByIndex(-1));
394405
ASSERT_EQ(std::nullopt, struct_.GetFieldByName("element"));
406+
ASSERT_EQ(std::nullopt, struct_.GetFieldByNameCaseInsensitive("element"));
395407
}
396408
ASSERT_THAT(
397409
[]() {

0 commit comments

Comments
 (0)