Skip to content

Commit cb4998c

Browse files
nullccxsynullccxsy
andauthored
feat: lazy init for Schema and StructType (#227)
…Type - Added move and copy constructors and assignment operators for Schema and StructType to manage resource ownership and improve performance. - Refactored field lookup methods to utilize lazy initialization with thread safety, ensuring safe concurrent access. - Introduced unit tests for thread safety in Schema and StructType, validating concurrent operations and access patterns. --------- Co-authored-by: nullccxsy <[email protected]>
1 parent b622584 commit cb4998c

File tree

8 files changed

+212
-28
lines changed

8 files changed

+212
-28
lines changed

src/iceberg/manifest_list.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ struct ICEBERG_EXPORT ManifestFile {
185185
507, "partitions",
186186
std::make_shared<ListType>(SchemaField::MakeRequired(
187187
508, std::string(ListType::kElementName),
188-
std::make_shared<StructType>(PartitionFieldSummary::Type()))),
188+
struct_(
189+
{PartitionFieldSummary::kContainsNull, PartitionFieldSummary::kContainsNaN,
190+
PartitionFieldSummary::kLowerBound, PartitionFieldSummary::kUpperBound}))),
189191
"Summary for each partition");
190192
inline static const SchemaField kKeyMetadata = SchemaField::MakeOptional(
191193
519, "key_metadata", iceberg::binary(), "Encryption key metadata blob");

src/iceberg/schema.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,14 @@ bool Schema::Equals(const Schema& other) const {
8989
Result<std::optional<std::reference_wrapper<const SchemaField>>> Schema::FindFieldByName(
9090
std::string_view name, bool case_sensitive) const {
9191
if (case_sensitive) {
92-
ICEBERG_RETURN_UNEXPECTED(InitNameToIdMap());
92+
ICEBERG_RETURN_UNEXPECTED(
93+
LazyInitWithCallOnce(name_to_id_flag_, [this]() { return InitNameToIdMap(); }));
9394
auto it = name_to_id_.find(name);
9495
if (it == name_to_id_.end()) return std::nullopt;
9596
return FindFieldById(it->second);
9697
}
97-
ICEBERG_RETURN_UNEXPECTED(InitLowerCaseNameToIdMap());
98+
ICEBERG_RETURN_UNEXPECTED(LazyInitWithCallOnce(
99+
lowercase_name_to_id_flag_, [this]() { return InitLowerCaseNameToIdMap(); }));
98100
auto it = lowercase_name_to_id_.find(StringUtils::ToLower(name));
99101
if (it == lowercase_name_to_id_.end()) return std::nullopt;
100102
return FindFieldById(it->second);
@@ -133,7 +135,8 @@ Status Schema::InitLowerCaseNameToIdMap() const {
133135

134136
Result<std::optional<std::reference_wrapper<const SchemaField>>> Schema::FindFieldById(
135137
int32_t field_id) const {
136-
ICEBERG_RETURN_UNEXPECTED(InitIdToFieldMap());
138+
ICEBERG_RETURN_UNEXPECTED(
139+
LazyInitWithCallOnce(id_to_field_flag_, [this]() { return InitIdToFieldMap(); }));
137140
auto it = id_to_field_.find(field_id);
138141
if (it == id_to_field_.end()) {
139142
return std::nullopt;

src/iceberg/schema.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
/// and any utility functions. See iceberg/type.h and iceberg/field.h as well.
2525

2626
#include <cstdint>
27+
#include <mutex>
2728
#include <optional>
2829
#include <string>
2930
#include <vector>
@@ -78,8 +79,6 @@ class ICEBERG_EXPORT Schema : public StructType {
7879
/// \brief Compare two schemas for equality.
7980
[[nodiscard]] bool Equals(const Schema& other) const;
8081

81-
// TODO(nullccxsy): Address potential concurrency issues in lazy initialization (e.g.,
82-
// use std::call_once)
8382
Status InitIdToFieldMap() const;
8483
Status InitNameToIdMap() const;
8584
Status InitLowerCaseNameToIdMap() const;
@@ -94,6 +93,10 @@ class ICEBERG_EXPORT Schema : public StructType {
9493
/// Mapping from lowercased field name to field id
9594
mutable std::unordered_map<std::string, int32_t, StringHash, std::equal_to<>>
9695
lowercase_name_to_id_;
96+
97+
mutable std::once_flag id_to_field_flag_;
98+
mutable std::once_flag name_to_id_flag_;
99+
mutable std::once_flag lowercase_name_to_id_flag_;
97100
};
98101

99102
} // namespace iceberg

src/iceberg/type.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ std::string StructType::ToString() const {
5050
std::span<const SchemaField> StructType::fields() const { return fields_; }
5151
Result<std::optional<NestedType::SchemaFieldConstRef>> StructType::GetFieldById(
5252
int32_t field_id) const {
53-
ICEBERG_RETURN_UNEXPECTED(InitFieldById());
53+
ICEBERG_RETURN_UNEXPECTED(
54+
LazyInitWithCallOnce(field_by_id_flag_, [this]() { return InitFieldById(); }));
5455
auto it = field_by_id_.find(field_id);
5556
if (it == field_by_id_.end()) return std::nullopt;
5657
return it->second;
@@ -65,14 +66,16 @@ Result<std::optional<NestedType::SchemaFieldConstRef>> StructType::GetFieldByInd
6566
Result<std::optional<NestedType::SchemaFieldConstRef>> StructType::GetFieldByName(
6667
std::string_view name, bool case_sensitive) const {
6768
if (case_sensitive) {
68-
ICEBERG_RETURN_UNEXPECTED(InitFieldByName());
69+
ICEBERG_RETURN_UNEXPECTED(LazyInitWithCallOnce(
70+
field_by_name_flag_, [this]() { return InitFieldByName(); }));
6971
auto it = field_by_name_.find(name);
7072
if (it != field_by_name_.end()) {
7173
return it->second;
7274
}
7375
return std::nullopt;
7476
}
75-
ICEBERG_RETURN_UNEXPECTED(InitFieldByLowerCaseName());
77+
ICEBERG_RETURN_UNEXPECTED(LazyInitWithCallOnce(
78+
field_by_lowercase_name_flag_, [this]() { return InitFieldByLowerCaseName(); }));
7679
auto it = field_by_lowercase_name_.find(StringUtils::ToLower(name));
7780
if (it != field_by_lowercase_name_.end()) {
7881
return it->second;

src/iceberg/type.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <array>
2727
#include <cstdint>
2828
#include <memory>
29+
#include <mutex>
2930
#include <optional>
3031
#include <span>
3132
#include <string>
@@ -39,6 +40,13 @@
3940

4041
namespace iceberg {
4142

43+
template <typename Func>
44+
Status LazyInitWithCallOnce(std::once_flag& flag, Func&& func) {
45+
Status status;
46+
std::call_once(flag, [&status, &func]() { status = func(); });
47+
return status;
48+
}
49+
4250
/// \brief Interface for a data type for a field.
4351
class ICEBERG_EXPORT Type : public iceberg::util::Formattable {
4452
public:
@@ -124,8 +132,7 @@ class ICEBERG_EXPORT StructType : public NestedType {
124132

125133
protected:
126134
bool Equals(const Type& other) const override;
127-
// TODO(nullccxsy): Lazy initialization has concurrency issues, need to add proper
128-
// synchronization mechanism
135+
129136
Status InitFieldById() const;
130137
Status InitFieldByName() const;
131138
Status InitFieldByLowerCaseName() const;
@@ -134,6 +141,10 @@ class ICEBERG_EXPORT StructType : public NestedType {
134141
mutable std::unordered_map<int32_t, SchemaFieldConstRef> field_by_id_;
135142
mutable std::unordered_map<std::string_view, SchemaFieldConstRef> field_by_name_;
136143
mutable std::unordered_map<std::string, SchemaFieldConstRef> field_by_lowercase_name_;
144+
145+
mutable std::once_flag field_by_id_flag_;
146+
mutable std::once_flag field_by_name_flag_;
147+
mutable std::once_flag field_by_lowercase_name_flag_;
137148
};
138149

139150
/// \brief A data type representing a list of values.

test/avro_data_test.cc

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,16 +1195,16 @@ TEST(ExtractDatumFromArrayTest, NullHandling) {
11951195

11961196
struct RoundTripParam {
11971197
std::string name;
1198-
Schema iceberg_schema;
1198+
std::shared_ptr<Schema> iceberg_schema;
11991199
std::string arrow_json;
12001200
};
12011201

12021202
void VerifyRoundTripConversion(const RoundTripParam& test_case) {
12031203
::avro::NodePtr avro_node;
1204-
ASSERT_THAT(ToAvroNodeVisitor{}.Visit(test_case.iceberg_schema, &avro_node), IsOk());
1204+
ASSERT_THAT(ToAvroNodeVisitor{}.Visit(*test_case.iceberg_schema, &avro_node), IsOk());
12051205

12061206
ArrowSchema arrow_c_schema;
1207-
ASSERT_THAT(ToArrowSchema(test_case.iceberg_schema, &arrow_c_schema), IsOk());
1207+
ASSERT_THAT(ToArrowSchema(*test_case.iceberg_schema, &arrow_c_schema), IsOk());
12081208
auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie();
12091209
auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields());
12101210

@@ -1221,14 +1221,14 @@ void VerifyRoundTripConversion(const RoundTripParam& test_case) {
12211221
}
12221222

12231223
auto projection_result =
1224-
Project(test_case.iceberg_schema, avro_node, /*prune_source=*/false);
1224+
Project(*test_case.iceberg_schema, avro_node, /*prune_source=*/false);
12251225
ASSERT_THAT(projection_result, IsOk());
12261226
auto projection = std::move(projection_result.value());
12271227

12281228
auto builder = ::arrow::MakeBuilder(arrow_struct_type).ValueOrDie();
12291229
for (const auto& datum : extracted_data) {
12301230
ASSERT_THAT(AppendDatumToBuilder(avro_node, datum, projection,
1231-
test_case.iceberg_schema, builder.get()),
1231+
*test_case.iceberg_schema, builder.get()),
12321232
IsOk());
12331233
}
12341234

@@ -1249,7 +1249,7 @@ TEST_P(AvroRoundTripConversionTest, ConvertTypes) {
12491249
const std::vector<RoundTripParam> kRoundTripTestCases = {
12501250
{
12511251
.name = "SimpleStruct",
1252-
.iceberg_schema = Schema({
1252+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
12531253
SchemaField::MakeRequired(1, "id", int32()),
12541254
SchemaField::MakeRequired(2, "name", string()),
12551255
SchemaField::MakeOptional(3, "age", int32()),
@@ -1262,7 +1262,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
12621262
},
12631263
{
12641264
.name = "PrimitiveTypes",
1265-
.iceberg_schema = Schema({
1265+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
12661266
SchemaField::MakeRequired(1, "bool_field", boolean()),
12671267
SchemaField::MakeRequired(2, "int_field", int32()),
12681268
SchemaField::MakeRequired(3, "long_field", int64()),
@@ -1277,7 +1277,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
12771277
},
12781278
{
12791279
.name = "NestedStruct",
1280-
.iceberg_schema = Schema({
1280+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
12811281
SchemaField::MakeRequired(1, "id", int32()),
12821282
SchemaField::MakeRequired(
12831283
2, "person",
@@ -1293,7 +1293,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
12931293
},
12941294
{
12951295
.name = "ListOfIntegers",
1296-
.iceberg_schema = Schema({
1296+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
12971297
SchemaField::MakeRequired(
12981298
1, "numbers",
12991299
std::make_shared<ListType>(
@@ -1307,7 +1307,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
13071307
},
13081308
{
13091309
.name = "MapStringToInt",
1310-
.iceberg_schema = Schema({
1310+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
13111311
SchemaField::MakeRequired(
13121312
1, "scores",
13131313
std::make_shared<MapType>(
@@ -1322,7 +1322,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
13221322
},
13231323
{
13241324
.name = "ComplexNested",
1325-
.iceberg_schema = Schema({
1325+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
13261326
SchemaField::MakeRequired(
13271327
1, "data",
13281328
std::make_shared<StructType>(std::vector<SchemaField>{
@@ -1345,7 +1345,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
13451345
},
13461346
{
13471347
.name = "NullablePrimitives",
1348-
.iceberg_schema = Schema({
1348+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
13491349
SchemaField::MakeOptional(1, "optional_bool", boolean()),
13501350
SchemaField::MakeOptional(2, "optional_int", int32()),
13511351
SchemaField::MakeOptional(3, "optional_long", int64()),
@@ -1361,7 +1361,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
13611361
},
13621362
{
13631363
.name = "NullableNestedStruct",
1364-
.iceberg_schema = Schema({
1364+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
13651365
SchemaField::MakeRequired(1, "id", int32()),
13661366
SchemaField::MakeOptional(
13671367
2, "person",
@@ -1381,7 +1381,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
13811381
},
13821382
{
13831383
.name = "NullableListElements",
1384-
.iceberg_schema = Schema({
1384+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
13851385
SchemaField::MakeRequired(1, "id", int32()),
13861386
SchemaField::MakeOptional(
13871387
2, "numbers",
@@ -1401,7 +1401,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
14011401
},
14021402
{
14031403
.name = "NullableMapValues",
1404-
.iceberg_schema = Schema({
1404+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
14051405
SchemaField::MakeRequired(1, "id", int32()),
14061406
SchemaField::MakeOptional(
14071407
2, "scores",
@@ -1423,7 +1423,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
14231423
},
14241424
{
14251425
.name = "DeeplyNestedWithNulls",
1426-
.iceberg_schema = Schema({
1426+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
14271427
SchemaField::MakeRequired(
14281428
1, "root",
14291429
std::make_shared<StructType>(std::vector<SchemaField>{
@@ -1452,7 +1452,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
14521452
},
14531453
{
14541454
.name = "AllNullsVariations",
1455-
.iceberg_schema = Schema({
1455+
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
14561456
SchemaField::MakeOptional(1, "always_null", string()),
14571457
SchemaField::MakeOptional(2, "sometimes_null", int32()),
14581458
SchemaField::MakeOptional(

test/schema_test.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <format>
2323
#include <memory>
24+
#include <thread>
2425

2526
#include <gmock/gmock.h>
2627
#include <gtest/gtest.h>
@@ -490,3 +491,71 @@ TEST(SchemaTest, NestedDuplicateFieldIdError) {
490491
EXPECT_THAT(result.error().message,
491492
::testing::HasSubstr("Duplicate field id found: 1"));
492493
}
494+
495+
// Thread safety tests for Lazy Init
496+
class SchemaThreadSafetyTest : public ::testing::Test {
497+
protected:
498+
void SetUp() override {
499+
field1_ = std::make_unique<iceberg::SchemaField>(1, "id", iceberg::int32(), true);
500+
field2_ = std::make_unique<iceberg::SchemaField>(2, "name", iceberg::string(), true);
501+
field3_ = std::make_unique<iceberg::SchemaField>(3, "age", iceberg::int32(), true);
502+
schema_ = std::make_unique<iceberg::Schema>(
503+
std::vector<iceberg::SchemaField>{*field1_, *field2_, *field3_}, 100);
504+
}
505+
506+
std::unique_ptr<iceberg::Schema> schema_;
507+
std::unique_ptr<iceberg::SchemaField> field1_;
508+
std::unique_ptr<iceberg::SchemaField> field2_;
509+
std::unique_ptr<iceberg::SchemaField> field3_;
510+
};
511+
512+
TEST_F(SchemaThreadSafetyTest, ConcurrentFindFieldById) {
513+
const int num_threads = 10;
514+
const int iterations_per_thread = 100;
515+
std::vector<std::thread> threads;
516+
517+
for (int i = 0; i < num_threads; ++i) {
518+
threads.emplace_back([this, iterations_per_thread]() {
519+
for (int j = 0; j < iterations_per_thread; ++j) {
520+
ASSERT_THAT(schema_->FindFieldById(1), ::testing::Optional(*field1_));
521+
ASSERT_THAT(schema_->FindFieldById(999), ::testing::Optional(std::nullopt));
522+
}
523+
});
524+
}
525+
526+
for (auto& thread : threads) {
527+
thread.join();
528+
}
529+
}
530+
531+
TEST_F(SchemaThreadSafetyTest, MixedConcurrentOperations) {
532+
const int num_threads = 8;
533+
const int iterations_per_thread = 50;
534+
std::vector<std::thread> threads;
535+
536+
for (int i = 0; i < num_threads; ++i) {
537+
threads.emplace_back([this, iterations_per_thread, i]() {
538+
for (int j = 0; j < iterations_per_thread; ++j) {
539+
if (i % 4 == 0) {
540+
ASSERT_THAT(schema_->FindFieldById(1), ::testing::Optional(*field1_));
541+
} else if (i % 4 == 1) {
542+
ASSERT_THAT(schema_->FindFieldByName("name", true),
543+
::testing::Optional(*field2_));
544+
} else if (i % 4 == 2) {
545+
ASSERT_THAT(schema_->FindFieldByName("AGE", false),
546+
::testing::Optional(*field3_));
547+
} else {
548+
ASSERT_THAT(schema_->FindFieldById(2), ::testing::Optional(*field2_));
549+
ASSERT_THAT(schema_->FindFieldByName("id", true),
550+
::testing::Optional(*field1_));
551+
ASSERT_THAT(schema_->FindFieldByName("age", false),
552+
::testing::Optional(*field3_));
553+
}
554+
}
555+
});
556+
}
557+
558+
for (auto& thread : threads) {
559+
thread.join();
560+
}
561+
}

0 commit comments

Comments
 (0)