Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/iceberg/manifest_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ struct ICEBERG_EXPORT ManifestFile {
507, "partitions",
std::make_shared<ListType>(SchemaField::MakeRequired(
508, std::string(ListType::kElementName),
std::make_shared<StructType>(PartitionFieldSummary::Type()))),
struct_(
{PartitionFieldSummary::kContainsNull, PartitionFieldSummary::kContainsNaN,
PartitionFieldSummary::kLowerBound, PartitionFieldSummary::kUpperBound}))),
"Summary for each partition");
inline static const SchemaField kKeyMetadata = SchemaField::MakeOptional(
519, "key_metadata", iceberg::binary(), "Encryption key metadata blob");
Expand Down
9 changes: 6 additions & 3 deletions src/iceberg/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ bool Schema::Equals(const Schema& other) const {
Result<std::optional<std::reference_wrapper<const SchemaField>>> Schema::FindFieldByName(
std::string_view name, bool case_sensitive) const {
if (case_sensitive) {
ICEBERG_RETURN_UNEXPECTED(InitNameToIdMap());
ICEBERG_RETURN_UNEXPECTED(
LazyInitWithCallOnce(name_to_id_flag_, [this]() { return InitNameToIdMap(); }));
auto it = name_to_id_.find(name);
if (it == name_to_id_.end()) return std::nullopt;
return FindFieldById(it->second);
}
ICEBERG_RETURN_UNEXPECTED(InitLowerCaseNameToIdMap());
ICEBERG_RETURN_UNEXPECTED(LazyInitWithCallOnce(
lowercase_name_to_id_flag_, [this]() { return InitLowerCaseNameToIdMap(); }));
auto it = lowercase_name_to_id_.find(StringUtils::ToLower(name));
if (it == lowercase_name_to_id_.end()) return std::nullopt;
return FindFieldById(it->second);
Expand Down Expand Up @@ -133,7 +135,8 @@ Status Schema::InitLowerCaseNameToIdMap() const {

Result<std::optional<std::reference_wrapper<const SchemaField>>> Schema::FindFieldById(
int32_t field_id) const {
ICEBERG_RETURN_UNEXPECTED(InitIdToFieldMap());
ICEBERG_RETURN_UNEXPECTED(
LazyInitWithCallOnce(id_to_field_flag_, [this]() { return InitIdToFieldMap(); }));
auto it = id_to_field_.find(field_id);
if (it == id_to_field_.end()) {
return std::nullopt;
Expand Down
7 changes: 5 additions & 2 deletions src/iceberg/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
/// and any utility functions. See iceberg/type.h and iceberg/field.h as well.

#include <cstdint>
#include <mutex>
#include <optional>
#include <string>
#include <vector>
Expand Down Expand Up @@ -78,8 +79,6 @@ class ICEBERG_EXPORT Schema : public StructType {
/// \brief Compare two schemas for equality.
[[nodiscard]] bool Equals(const Schema& other) const;

// TODO(nullccxsy): Address potential concurrency issues in lazy initialization (e.g.,
// use std::call_once)
Status InitIdToFieldMap() const;
Status InitNameToIdMap() const;
Status InitLowerCaseNameToIdMap() const;
Expand All @@ -94,6 +93,10 @@ class ICEBERG_EXPORT Schema : public StructType {
/// Mapping from lowercased field name to field id
mutable std::unordered_map<std::string, int32_t, StringHash, std::equal_to<>>
lowercase_name_to_id_;

mutable std::once_flag id_to_field_flag_;
mutable std::once_flag name_to_id_flag_;
mutable std::once_flag lowercase_name_to_id_flag_;
};

} // namespace iceberg
9 changes: 6 additions & 3 deletions src/iceberg/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ std::string StructType::ToString() const {
std::span<const SchemaField> StructType::fields() const { return fields_; }
Result<std::optional<NestedType::SchemaFieldConstRef>> StructType::GetFieldById(
int32_t field_id) const {
ICEBERG_RETURN_UNEXPECTED(InitFieldById());
ICEBERG_RETURN_UNEXPECTED(
LazyInitWithCallOnce(field_by_id_flag_, [this]() { return InitFieldById(); }));
auto it = field_by_id_.find(field_id);
if (it == field_by_id_.end()) return std::nullopt;
return it->second;
Expand All @@ -65,14 +66,16 @@ Result<std::optional<NestedType::SchemaFieldConstRef>> StructType::GetFieldByInd
Result<std::optional<NestedType::SchemaFieldConstRef>> StructType::GetFieldByName(
std::string_view name, bool case_sensitive) const {
if (case_sensitive) {
ICEBERG_RETURN_UNEXPECTED(InitFieldByName());
ICEBERG_RETURN_UNEXPECTED(LazyInitWithCallOnce(
field_by_name_flag_, [this]() { return InitFieldByName(); }));
auto it = field_by_name_.find(name);
if (it != field_by_name_.end()) {
return it->second;
}
return std::nullopt;
}
ICEBERG_RETURN_UNEXPECTED(InitFieldByLowerCaseName());
ICEBERG_RETURN_UNEXPECTED(LazyInitWithCallOnce(
field_by_lowercase_name_flag_, [this]() { return InitFieldByLowerCaseName(); }));
auto it = field_by_lowercase_name_.find(StringUtils::ToLower(name));
if (it != field_by_lowercase_name_.end()) {
return it->second;
Expand Down
15 changes: 13 additions & 2 deletions src/iceberg/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <array>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <span>
#include <string>
Expand All @@ -39,6 +40,13 @@

namespace iceberg {

template <typename Func>
Status LazyInitWithCallOnce(std::once_flag& flag, Func&& func) {
Status status;
std::call_once(flag, [&status, &func]() { status = func(); });
return status;
}

/// \brief Interface for a data type for a field.
class ICEBERG_EXPORT Type : public iceberg::util::Formattable {
public:
Expand Down Expand Up @@ -124,8 +132,7 @@ class ICEBERG_EXPORT StructType : public NestedType {

protected:
bool Equals(const Type& other) const override;
// TODO(nullccxsy): Lazy initialization has concurrency issues, need to add proper
// synchronization mechanism

Status InitFieldById() const;
Status InitFieldByName() const;
Status InitFieldByLowerCaseName() const;
Expand All @@ -134,6 +141,10 @@ class ICEBERG_EXPORT StructType : public NestedType {
mutable std::unordered_map<int32_t, SchemaFieldConstRef> field_by_id_;
mutable std::unordered_map<std::string_view, SchemaFieldConstRef> field_by_name_;
mutable std::unordered_map<std::string, SchemaFieldConstRef> field_by_lowercase_name_;

mutable std::once_flag field_by_id_flag_;
mutable std::once_flag field_by_name_flag_;
mutable std::once_flag field_by_lowercase_name_flag_;
};

/// \brief A data type representing a list of values.
Expand Down
34 changes: 17 additions & 17 deletions test/avro_data_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1195,16 +1195,16 @@ TEST(ExtractDatumFromArrayTest, NullHandling) {

struct RoundTripParam {
std::string name;
Schema iceberg_schema;
std::shared_ptr<Schema> iceberg_schema;
std::string arrow_json;
};

void VerifyRoundTripConversion(const RoundTripParam& test_case) {
::avro::NodePtr avro_node;
ASSERT_THAT(ToAvroNodeVisitor{}.Visit(test_case.iceberg_schema, &avro_node), IsOk());
ASSERT_THAT(ToAvroNodeVisitor{}.Visit(*test_case.iceberg_schema, &avro_node), IsOk());

ArrowSchema arrow_c_schema;
ASSERT_THAT(ToArrowSchema(test_case.iceberg_schema, &arrow_c_schema), IsOk());
ASSERT_THAT(ToArrowSchema(*test_case.iceberg_schema, &arrow_c_schema), IsOk());
auto arrow_schema = ::arrow::ImportSchema(&arrow_c_schema).ValueOrDie();
auto arrow_struct_type = std::make_shared<::arrow::StructType>(arrow_schema->fields());

Expand All @@ -1221,14 +1221,14 @@ void VerifyRoundTripConversion(const RoundTripParam& test_case) {
}

auto projection_result =
Project(test_case.iceberg_schema, avro_node, /*prune_source=*/false);
Project(*test_case.iceberg_schema, avro_node, /*prune_source=*/false);
ASSERT_THAT(projection_result, IsOk());
auto projection = std::move(projection_result.value());

auto builder = ::arrow::MakeBuilder(arrow_struct_type).ValueOrDie();
for (const auto& datum : extracted_data) {
ASSERT_THAT(AppendDatumToBuilder(avro_node, datum, projection,
test_case.iceberg_schema, builder.get()),
*test_case.iceberg_schema, builder.get()),
IsOk());
}

Expand All @@ -1249,7 +1249,7 @@ TEST_P(AvroRoundTripConversionTest, ConvertTypes) {
const std::vector<RoundTripParam> kRoundTripTestCases = {
{
.name = "SimpleStruct",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(1, "id", int32()),
SchemaField::MakeRequired(2, "name", string()),
SchemaField::MakeOptional(3, "age", int32()),
Expand All @@ -1262,7 +1262,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "PrimitiveTypes",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(1, "bool_field", boolean()),
SchemaField::MakeRequired(2, "int_field", int32()),
SchemaField::MakeRequired(3, "long_field", int64()),
Expand All @@ -1277,7 +1277,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "NestedStruct",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(1, "id", int32()),
SchemaField::MakeRequired(
2, "person",
Expand All @@ -1293,7 +1293,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "ListOfIntegers",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(
1, "numbers",
std::make_shared<ListType>(
Expand All @@ -1307,7 +1307,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "MapStringToInt",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(
1, "scores",
std::make_shared<MapType>(
Expand All @@ -1322,7 +1322,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "ComplexNested",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(
1, "data",
std::make_shared<StructType>(std::vector<SchemaField>{
Expand All @@ -1345,7 +1345,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "NullablePrimitives",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeOptional(1, "optional_bool", boolean()),
SchemaField::MakeOptional(2, "optional_int", int32()),
SchemaField::MakeOptional(3, "optional_long", int64()),
Expand All @@ -1361,7 +1361,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "NullableNestedStruct",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(1, "id", int32()),
SchemaField::MakeOptional(
2, "person",
Expand All @@ -1381,7 +1381,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "NullableListElements",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(1, "id", int32()),
SchemaField::MakeOptional(
2, "numbers",
Expand All @@ -1401,7 +1401,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "NullableMapValues",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(1, "id", int32()),
SchemaField::MakeOptional(
2, "scores",
Expand All @@ -1423,7 +1423,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "DeeplyNestedWithNulls",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeRequired(
1, "root",
std::make_shared<StructType>(std::vector<SchemaField>{
Expand Down Expand Up @@ -1452,7 +1452,7 @@ const std::vector<RoundTripParam> kRoundTripTestCases = {
},
{
.name = "AllNullsVariations",
.iceberg_schema = Schema({
.iceberg_schema = std::make_shared<Schema>(std::vector<SchemaField>{
SchemaField::MakeOptional(1, "always_null", string()),
SchemaField::MakeOptional(2, "sometimes_null", int32()),
SchemaField::MakeOptional(
Expand Down
69 changes: 69 additions & 0 deletions test/schema_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <format>
#include <memory>
#include <thread>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
Expand Down Expand Up @@ -490,3 +491,71 @@ TEST(SchemaTest, NestedDuplicateFieldIdError) {
EXPECT_THAT(result.error().message,
::testing::HasSubstr("Duplicate field id found: 1"));
}

// Thread safety tests for Lazy Init
class SchemaThreadSafetyTest : public ::testing::Test {
protected:
void SetUp() override {
field1_ = std::make_unique<iceberg::SchemaField>(1, "id", iceberg::int32(), true);
field2_ = std::make_unique<iceberg::SchemaField>(2, "name", iceberg::string(), true);
field3_ = std::make_unique<iceberg::SchemaField>(3, "age", iceberg::int32(), true);
schema_ = std::make_unique<iceberg::Schema>(
std::vector<iceberg::SchemaField>{*field1_, *field2_, *field3_}, 100);
}

std::unique_ptr<iceberg::Schema> schema_;
std::unique_ptr<iceberg::SchemaField> field1_;
std::unique_ptr<iceberg::SchemaField> field2_;
std::unique_ptr<iceberg::SchemaField> field3_;
};

TEST_F(SchemaThreadSafetyTest, ConcurrentFindFieldById) {
const int num_threads = 10;
const int iterations_per_thread = 100;
std::vector<std::thread> threads;

for (int i = 0; i < num_threads; ++i) {
threads.emplace_back([this, iterations_per_thread]() {
for (int j = 0; j < iterations_per_thread; ++j) {
ASSERT_THAT(schema_->FindFieldById(1), ::testing::Optional(*field1_));
ASSERT_THAT(schema_->FindFieldById(999), ::testing::Optional(std::nullopt));
}
});
}

for (auto& thread : threads) {
thread.join();
}
}

TEST_F(SchemaThreadSafetyTest, MixedConcurrentOperations) {
const int num_threads = 8;
const int iterations_per_thread = 50;
std::vector<std::thread> threads;

for (int i = 0; i < num_threads; ++i) {
threads.emplace_back([this, iterations_per_thread, i]() {
for (int j = 0; j < iterations_per_thread; ++j) {
if (i % 4 == 0) {
ASSERT_THAT(schema_->FindFieldById(1), ::testing::Optional(*field1_));
} else if (i % 4 == 1) {
ASSERT_THAT(schema_->FindFieldByName("name", true),
::testing::Optional(*field2_));
} else if (i % 4 == 2) {
ASSERT_THAT(schema_->FindFieldByName("AGE", false),
::testing::Optional(*field3_));
} else {
ASSERT_THAT(schema_->FindFieldById(2), ::testing::Optional(*field2_));
ASSERT_THAT(schema_->FindFieldByName("id", true),
::testing::Optional(*field1_));
ASSERT_THAT(schema_->FindFieldByName("age", false),
::testing::Optional(*field3_));
}
}
});
}

for (auto& thread : threads) {
thread.join();
}
}
Loading
Loading