diff --git a/src/iceberg/manifest_list.h b/src/iceberg/manifest_list.h index 66433da39..d6e7f1dd2 100644 --- a/src/iceberg/manifest_list.h +++ b/src/iceberg/manifest_list.h @@ -185,7 +185,9 @@ struct ICEBERG_EXPORT ManifestFile { 507, "partitions", std::make_shared(SchemaField::MakeRequired( 508, std::string(ListType::kElementName), - std::make_shared(PartitionFieldSummary::Type()))), + struct_( + {PartitionFieldSummary::kContainsNull, PartitionFieldSummary::kContainsNaN, + PartitionFieldSummary::kLowerBound, PartitionFieldSummary::kUpperBound}))), "Summary for each partition"); inline static const SchemaField kKeyMetadata = SchemaField::MakeOptional( 519, "key_metadata", iceberg::binary(), "Encryption key metadata blob"); diff --git a/src/iceberg/schema.cc b/src/iceberg/schema.cc index 0b67dfb0d..1df20c60b 100644 --- a/src/iceberg/schema.cc +++ b/src/iceberg/schema.cc @@ -89,12 +89,14 @@ bool Schema::Equals(const Schema& other) const { Result>> Schema::FindFieldByName( std::string_view name, bool case_sensitive) const { if (case_sensitive) { - ICEBERG_RETURN_UNEXPECTED(InitNameToIdMap()); + ICEBERG_RETURN_UNEXPECTED( + LazyInitWithCallOnce(name_to_id_flag_, [this]() { return InitNameToIdMap(); })); auto it = name_to_id_.find(name); if (it == name_to_id_.end()) return std::nullopt; return FindFieldById(it->second); } - ICEBERG_RETURN_UNEXPECTED(InitLowerCaseNameToIdMap()); + ICEBERG_RETURN_UNEXPECTED(LazyInitWithCallOnce( + lowercase_name_to_id_flag_, [this]() { return InitLowerCaseNameToIdMap(); })); auto it = lowercase_name_to_id_.find(StringUtils::ToLower(name)); if (it == lowercase_name_to_id_.end()) return std::nullopt; return FindFieldById(it->second); @@ -133,7 +135,8 @@ Status Schema::InitLowerCaseNameToIdMap() const { Result>> Schema::FindFieldById( int32_t field_id) const { - ICEBERG_RETURN_UNEXPECTED(InitIdToFieldMap()); + ICEBERG_RETURN_UNEXPECTED( + LazyInitWithCallOnce(id_to_field_flag_, [this]() { return InitIdToFieldMap(); })); auto it = id_to_field_.find(field_id); if (it == id_to_field_.end()) { return std::nullopt; diff --git a/src/iceberg/schema.h b/src/iceberg/schema.h index 1de829c80..260d9d342 100644 --- a/src/iceberg/schema.h +++ b/src/iceberg/schema.h @@ -24,6 +24,7 @@ /// and any utility functions. See iceberg/type.h and iceberg/field.h as well. #include +#include #include #include #include @@ -78,8 +79,6 @@ class ICEBERG_EXPORT Schema : public StructType { /// \brief Compare two schemas for equality. [[nodiscard]] bool Equals(const Schema& other) const; - // TODO(nullccxsy): Address potential concurrency issues in lazy initialization (e.g., - // use std::call_once) Status InitIdToFieldMap() const; Status InitNameToIdMap() const; Status InitLowerCaseNameToIdMap() const; @@ -94,6 +93,10 @@ class ICEBERG_EXPORT Schema : public StructType { /// Mapping from lowercased field name to field id mutable std::unordered_map> lowercase_name_to_id_; + + mutable std::once_flag id_to_field_flag_; + mutable std::once_flag name_to_id_flag_; + mutable std::once_flag lowercase_name_to_id_flag_; }; } // namespace iceberg diff --git a/src/iceberg/type.cc b/src/iceberg/type.cc index b435bb329..8d230d7d0 100644 --- a/src/iceberg/type.cc +++ b/src/iceberg/type.cc @@ -50,7 +50,8 @@ std::string StructType::ToString() const { std::span StructType::fields() const { return fields_; } Result> StructType::GetFieldById( int32_t field_id) const { - ICEBERG_RETURN_UNEXPECTED(InitFieldById()); + ICEBERG_RETURN_UNEXPECTED( + LazyInitWithCallOnce(field_by_id_flag_, [this]() { return InitFieldById(); })); auto it = field_by_id_.find(field_id); if (it == field_by_id_.end()) return std::nullopt; return it->second; @@ -65,14 +66,16 @@ Result> StructType::GetFieldByInd Result> StructType::GetFieldByName( std::string_view name, bool case_sensitive) const { if (case_sensitive) { - ICEBERG_RETURN_UNEXPECTED(InitFieldByName()); + ICEBERG_RETURN_UNEXPECTED(LazyInitWithCallOnce( + field_by_name_flag_, [this]() { return InitFieldByName(); })); auto it = field_by_name_.find(name); if (it != field_by_name_.end()) { return it->second; } return std::nullopt; } - ICEBERG_RETURN_UNEXPECTED(InitFieldByLowerCaseName()); + ICEBERG_RETURN_UNEXPECTED(LazyInitWithCallOnce( + field_by_lowercase_name_flag_, [this]() { return InitFieldByLowerCaseName(); })); auto it = field_by_lowercase_name_.find(StringUtils::ToLower(name)); if (it != field_by_lowercase_name_.end()) { return it->second; diff --git a/src/iceberg/type.h b/src/iceberg/type.h index d1d193409..01c911dd8 100644 --- a/src/iceberg/type.h +++ b/src/iceberg/type.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -39,6 +40,13 @@ namespace iceberg { +template +Status LazyInitWithCallOnce(std::once_flag& flag, Func&& func) { + Status status; + std::call_once(flag, [&status, &func]() { status = func(); }); + return status; +} + /// \brief Interface for a data type for a field. class ICEBERG_EXPORT Type : public iceberg::util::Formattable { public: @@ -124,8 +132,7 @@ class ICEBERG_EXPORT StructType : public NestedType { protected: bool Equals(const Type& other) const override; - // TODO(nullccxsy): Lazy initialization has concurrency issues, need to add proper - // synchronization mechanism + Status InitFieldById() const; Status InitFieldByName() const; Status InitFieldByLowerCaseName() const; @@ -134,6 +141,10 @@ class ICEBERG_EXPORT StructType : public NestedType { mutable std::unordered_map field_by_id_; mutable std::unordered_map field_by_name_; mutable std::unordered_map field_by_lowercase_name_; + + mutable std::once_flag field_by_id_flag_; + mutable std::once_flag field_by_name_flag_; + mutable std::once_flag field_by_lowercase_name_flag_; }; /// \brief A data type representing a list of values. diff --git a/test/avro_data_test.cc b/test/avro_data_test.cc index b5cc1c5ae..2797f9bf6 100644 --- a/test/avro_data_test.cc +++ b/test/avro_data_test.cc @@ -1195,16 +1195,16 @@ TEST(ExtractDatumFromArrayTest, NullHandling) { struct RoundTripParam { std::string name; - Schema iceberg_schema; + std::shared_ptr iceberg_schema; std::string arrow_json; }; void VerifyRoundTripConversion(const RoundTripParam& test_case) { ::avro::NodePtr avro_node; - ASSERT_THAT(ToAvroNodeVisitor{}.Visit(test_case.iceberg_schema, &avro_node), IsOk()); + ASSERT_THAT(ToAvroNodeVisitor{}.Visit(*test_case.iceberg_schema, &avro_node), IsOk()); ArrowSchema arrow_c_schema; - ASSERT_THAT(ToArrowSchema(test_case.iceberg_schema, &arrow_c_schema), IsOk()); + ASSERT_THAT(ToArrowSchema(*test_case.iceberg_schema, &arrow_c_schema), IsOk()); auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie(); auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields()); @@ -1221,14 +1221,14 @@ void VerifyRoundTripConversion(const RoundTripParam& test_case) { } auto projection_result = - Project(test_case.iceberg_schema, avro_node, /*prune_source=*/false); + Project(*test_case.iceberg_schema, avro_node, /*prune_source=*/false); ASSERT_THAT(projection_result, IsOk()); auto projection = std::move(projection_result.value()); auto builder = ::arrow::MakeBuilder(arrow_struct_type).ValueOrDie(); for (const auto& datum : extracted_data) { ASSERT_THAT(AppendDatumToBuilder(avro_node, datum, projection, - test_case.iceberg_schema, builder.get()), + *test_case.iceberg_schema, builder.get()), IsOk()); } @@ -1249,7 +1249,7 @@ TEST_P(AvroRoundTripConversionTest, ConvertTypes) { const std::vector kRoundTripTestCases = { { .name = "SimpleStruct", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired(1, "id", int32()), SchemaField::MakeRequired(2, "name", string()), SchemaField::MakeOptional(3, "age", int32()), @@ -1262,7 +1262,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "PrimitiveTypes", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired(1, "bool_field", boolean()), SchemaField::MakeRequired(2, "int_field", int32()), SchemaField::MakeRequired(3, "long_field", int64()), @@ -1277,7 +1277,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "NestedStruct", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired(1, "id", int32()), SchemaField::MakeRequired( 2, "person", @@ -1293,7 +1293,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "ListOfIntegers", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired( 1, "numbers", std::make_shared( @@ -1307,7 +1307,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "MapStringToInt", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired( 1, "scores", std::make_shared( @@ -1322,7 +1322,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "ComplexNested", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired( 1, "data", std::make_shared(std::vector{ @@ -1345,7 +1345,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "NullablePrimitives", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeOptional(1, "optional_bool", boolean()), SchemaField::MakeOptional(2, "optional_int", int32()), SchemaField::MakeOptional(3, "optional_long", int64()), @@ -1361,7 +1361,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "NullableNestedStruct", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired(1, "id", int32()), SchemaField::MakeOptional( 2, "person", @@ -1381,7 +1381,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "NullableListElements", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired(1, "id", int32()), SchemaField::MakeOptional( 2, "numbers", @@ -1401,7 +1401,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "NullableMapValues", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired(1, "id", int32()), SchemaField::MakeOptional( 2, "scores", @@ -1423,7 +1423,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "DeeplyNestedWithNulls", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeRequired( 1, "root", std::make_shared(std::vector{ @@ -1452,7 +1452,7 @@ const std::vector kRoundTripTestCases = { }, { .name = "AllNullsVariations", - .iceberg_schema = Schema({ + .iceberg_schema = std::make_shared(std::vector{ SchemaField::MakeOptional(1, "always_null", string()), SchemaField::MakeOptional(2, "sometimes_null", int32()), SchemaField::MakeOptional( diff --git a/test/schema_test.cc b/test/schema_test.cc index 272c6e75a..b01ffe9ba 100644 --- a/test/schema_test.cc +++ b/test/schema_test.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -490,3 +491,71 @@ TEST(SchemaTest, NestedDuplicateFieldIdError) { EXPECT_THAT(result.error().message, ::testing::HasSubstr("Duplicate field id found: 1")); } + +// Thread safety tests for Lazy Init +class SchemaThreadSafetyTest : public ::testing::Test { + protected: + void SetUp() override { + field1_ = std::make_unique(1, "id", iceberg::int32(), true); + field2_ = std::make_unique(2, "name", iceberg::string(), true); + field3_ = std::make_unique(3, "age", iceberg::int32(), true); + schema_ = std::make_unique( + std::vector{*field1_, *field2_, *field3_}, 100); + } + + std::unique_ptr schema_; + std::unique_ptr field1_; + std::unique_ptr field2_; + std::unique_ptr field3_; +}; + +TEST_F(SchemaThreadSafetyTest, ConcurrentFindFieldById) { + const int num_threads = 10; + const int iterations_per_thread = 100; + std::vector threads; + + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back([this, iterations_per_thread]() { + for (int j = 0; j < iterations_per_thread; ++j) { + ASSERT_THAT(schema_->FindFieldById(1), ::testing::Optional(*field1_)); + ASSERT_THAT(schema_->FindFieldById(999), ::testing::Optional(std::nullopt)); + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(SchemaThreadSafetyTest, MixedConcurrentOperations) { + const int num_threads = 8; + const int iterations_per_thread = 50; + std::vector threads; + + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back([this, iterations_per_thread, i]() { + for (int j = 0; j < iterations_per_thread; ++j) { + if (i % 4 == 0) { + ASSERT_THAT(schema_->FindFieldById(1), ::testing::Optional(*field1_)); + } else if (i % 4 == 1) { + ASSERT_THAT(schema_->FindFieldByName("name", true), + ::testing::Optional(*field2_)); + } else if (i % 4 == 2) { + ASSERT_THAT(schema_->FindFieldByName("AGE", false), + ::testing::Optional(*field3_)); + } else { + ASSERT_THAT(schema_->FindFieldById(2), ::testing::Optional(*field2_)); + ASSERT_THAT(schema_->FindFieldByName("id", true), + ::testing::Optional(*field1_)); + ASSERT_THAT(schema_->FindFieldByName("age", false), + ::testing::Optional(*field3_)); + } + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } +} diff --git a/test/type_test.cc b/test/type_test.cc index 9963ab364..4fd8c46c1 100644 --- a/test/type_test.cc +++ b/test/type_test.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -511,3 +512,95 @@ TEST(TypeTest, StructDuplicateLowerCaseName) { iceberg::HasErrorMessage( "Duplicate lowercase field name found: foo (prev id: 1, curr id: 2)")); } + +// Thread safety tests for StructType Lazy Init +class StructTypeThreadSafetyTest : public ::testing::Test { + protected: + void SetUp() override { + field1_ = std::make_unique(1, "id", iceberg::int32(), true); + field2_ = std::make_unique(2, "name", iceberg::string(), true); + field3_ = std::make_unique(3, "age", iceberg::int32(), true); + + struct_type_ = std::make_unique( + std::vector{*field1_, *field2_, *field3_}); + } + + std::unique_ptr struct_type_; + std::unique_ptr field1_; + std::unique_ptr field2_; + std::unique_ptr field3_; +}; + +TEST_F(StructTypeThreadSafetyTest, ConcurrentGetFieldById) { + const int num_threads = 10; + const int iterations_per_thread = 100; + std::vector threads; + + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back([this, iterations_per_thread]() { + for (int j = 0; j < iterations_per_thread; ++j) { + ASSERT_THAT(struct_type_->GetFieldById(1), ::testing::Optional(*field1_)); + ASSERT_THAT(struct_type_->GetFieldById(999), ::testing::Optional(std::nullopt)); + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(StructTypeThreadSafetyTest, ConcurrentGetFieldByName) { + const int num_threads = 10; + const int iterations_per_thread = 100; + std::vector threads; + + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back([this, iterations_per_thread]() { + for (int j = 0; j < iterations_per_thread; ++j) { + ASSERT_THAT(struct_type_->GetFieldByName("id", true), + ::testing::Optional(*field1_)); + ASSERT_THAT(struct_type_->GetFieldByName("NAME", false), + ::testing::Optional(*field2_)); + ASSERT_THAT(struct_type_->GetFieldByName("noexist", false), + ::testing::Optional(std::nullopt)); + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } +} + +TEST_F(StructTypeThreadSafetyTest, MixedConcurrentOperations) { + const int num_threads = 8; + const int iterations_per_thread = 50; + std::vector threads; + + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back([this, iterations_per_thread, i]() { + for (int j = 0; j < iterations_per_thread; ++j) { + if (i % 4 == 0) { + ASSERT_THAT(struct_type_->GetFieldById(1), ::testing::Optional(*field1_)); + } else if (i % 4 == 1) { + ASSERT_THAT(struct_type_->GetFieldByName("name", true), + ::testing::Optional(*field2_)); + } else if (i % 4 == 2) { + ASSERT_THAT(struct_type_->GetFieldByName("AGE", false), + ::testing::Optional(*field3_)); + } else { + ASSERT_THAT(struct_type_->GetFieldById(2), ::testing::Optional(*field2_)); + ASSERT_THAT(struct_type_->GetFieldByName("id", true), + ::testing::Optional(*field1_)); + ASSERT_THAT(struct_type_->GetFieldByName("age", false), + ::testing::Optional(*field3_)); + } + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } +}