|
26 | 26 | #include <gmock/gmock.h> |
27 | 27 | #include <gtest/gtest.h> |
28 | 28 |
|
| 29 | +#include "gtest/gtest.h" |
29 | 30 | #include "iceberg/result.h" |
30 | 31 | #include "iceberg/schema_field.h" |
31 | 32 | #include "iceberg/util/formatter.h" // IWYU pragma: keep |
@@ -762,7 +763,7 @@ TEST_P(ProjectParamTest, ProjectFields) { |
762 | 763 |
|
763 | 764 | if (param.should_succeed) { |
764 | 765 | ASSERT_TRUE(result.has_value()); |
765 | | - ASSERT_THAT(*result.value(), *param.expected_schema()); |
| 766 | + ASSERT_EQ(*result.value(), *param.expected_schema()); |
766 | 767 | } else { |
767 | 768 | ASSERT_FALSE(result.has_value()); |
768 | 769 | ASSERT_THAT(result, iceberg::IsError(iceberg::ErrorKind::kInvalidArgument)); |
@@ -942,3 +943,70 @@ INSTANTIATE_TEST_SUITE_P( |
942 | 943 | .selected_ids = {999}, // Select non-existent field |
943 | 944 | .expected_schema = []() { return MakeSchema(); }, |
944 | 945 | .should_succeed = true})); |
| 946 | + |
| 947 | +class SchemaThreadSafetyTest : public ::testing::Test { |
| 948 | + protected: |
| 949 | + void SetUp() override { |
| 950 | + field1_ = std::make_unique<iceberg::SchemaField>(1, "id", iceberg::int32(), true); |
| 951 | + field2_ = std::make_unique<iceberg::SchemaField>(2, "name", iceberg::string(), true); |
| 952 | + field3_ = std::make_unique<iceberg::SchemaField>(3, "age", iceberg::int32(), true); |
| 953 | + schema_ = std::make_unique<iceberg::Schema>( |
| 954 | + std::vector<iceberg::SchemaField>{*field1_, *field2_, *field3_}, 100); |
| 955 | + } |
| 956 | + |
| 957 | + std::unique_ptr<iceberg::Schema> schema_; |
| 958 | + std::unique_ptr<iceberg::SchemaField> field1_; |
| 959 | + std::unique_ptr<iceberg::SchemaField> field2_; |
| 960 | + std::unique_ptr<iceberg::SchemaField> field3_; |
| 961 | +}; |
| 962 | + |
| 963 | +TEST_F(SchemaThreadSafetyTest, ConcurrentFindFieldById) { |
| 964 | + const int num_threads = 10; |
| 965 | + const int iterations_per_thread = 100; |
| 966 | + std::vector<std::thread> threads; |
| 967 | + |
| 968 | + for (int i = 0; i < num_threads; ++i) { |
| 969 | + threads.emplace_back([this, iterations_per_thread]() { |
| 970 | + for (int j = 0; j < iterations_per_thread; ++j) { |
| 971 | + ASSERT_THAT(schema_->FindFieldById(1), ::testing::Optional(*field1_)); |
| 972 | + ASSERT_THAT(schema_->FindFieldById(999), ::testing::Optional(std::nullopt)); |
| 973 | + } |
| 974 | + }); |
| 975 | + } |
| 976 | + |
| 977 | + for (auto& thread : threads) { |
| 978 | + thread.join(); |
| 979 | + } |
| 980 | +} |
| 981 | + |
| 982 | +TEST_F(SchemaThreadSafetyTest, MixedConcurrentOperations) { |
| 983 | + const int num_threads = 8; |
| 984 | + const int iterations_per_thread = 50; |
| 985 | + std::vector<std::thread> threads; |
| 986 | + |
| 987 | + for (int i = 0; i < num_threads; ++i) { |
| 988 | + threads.emplace_back([this, iterations_per_thread, i]() { |
| 989 | + for (int j = 0; j < iterations_per_thread; ++j) { |
| 990 | + if (i % 4 == 0) { |
| 991 | + ASSERT_THAT(schema_->FindFieldById(1), ::testing::Optional(*field1_)); |
| 992 | + } else if (i % 4 == 1) { |
| 993 | + ASSERT_THAT(schema_->FindFieldByName("name", true), |
| 994 | + ::testing::Optional(*field2_)); |
| 995 | + } else if (i % 4 == 2) { |
| 996 | + ASSERT_THAT(schema_->FindFieldByName("AGE", false), |
| 997 | + ::testing::Optional(*field3_)); |
| 998 | + } else { |
| 999 | + ASSERT_THAT(schema_->FindFieldById(2), ::testing::Optional(*field2_)); |
| 1000 | + ASSERT_THAT(schema_->FindFieldByName("id", true), |
| 1001 | + ::testing::Optional(*field1_)); |
| 1002 | + ASSERT_THAT(schema_->FindFieldByName("age", false), |
| 1003 | + ::testing::Optional(*field3_)); |
| 1004 | + } |
| 1005 | + } |
| 1006 | + }); |
| 1007 | + } |
| 1008 | + |
| 1009 | + for (auto& thread : threads) { |
| 1010 | + thread.join(); |
| 1011 | + } |
| 1012 | +} |
0 commit comments