diff --git a/cmake_modules/IcebergThirdpartyToolchain.cmake b/cmake_modules/IcebergThirdpartyToolchain.cmake index f962ae6da..e361f1e2f 100644 --- a/cmake_modules/IcebergThirdpartyToolchain.cmake +++ b/cmake_modules/IcebergThirdpartyToolchain.cmake @@ -23,9 +23,9 @@ set(ICEBERG_ARROW_INSTALL_INTERFACE_LIBS) # ---------------------------------------------------------------------- # Versions and URLs for toolchain builds -set(ICEBERG_ARROW_BUILD_VERSION "18.1.0") +set(ICEBERG_ARROW_BUILD_VERSION "19.0.1") set(ICEBERG_ARROW_BUILD_SHA256_CHECKSUM - "2dc8da5f8796afe213ecc5e5aba85bb82d91520eff3cf315784a52d0fa61d7fc") + "acb76266e8b0c2fbb7eb15d542fbb462a73b3fd1e32b80fad6c2fafd95a51160") if(DEFINED ENV{ICEBERG_ARROW_URL}) set(ARROW_SOURCE_URL "$ENV{ICEBERG_ARROW_URL}") diff --git a/src/iceberg/CMakeLists.txt b/src/iceberg/CMakeLists.txt index 64bc0efc2..817652246 100644 --- a/src/iceberg/CMakeLists.txt +++ b/src/iceberg/CMakeLists.txt @@ -22,6 +22,7 @@ set(ICEBERG_SOURCES demo.cc schema.cc schema_field.cc + schema_internal.cc type.cc) set(ICEBERG_STATIC_BUILD_INTERFACE_LIBS) diff --git a/src/iceberg/arrow_c_data.h b/src/iceberg/arrow_c_data.h index 43c2adbd8..7a4618ecd 100644 --- a/src/iceberg/arrow_c_data.h +++ b/src/iceberg/arrow_c_data.h @@ -29,10 +29,15 @@ #include +extern "C" { + #ifndef ARROW_C_DATA_INTERFACE # define ARROW_C_DATA_INTERFACE -extern "C" { +# define ARROW_FLAG_DICTIONARY_ORDERED 1 +# define ARROW_FLAG_NULLABLE 2 +# define ARROW_FLAG_MAP_KEYS_SORTED 4 + struct ArrowSchema { // Array type description const char* format; @@ -66,6 +71,6 @@ struct ArrowArray { void* private_data; }; -} // extern "C" - #endif // ARROW_C_DATA_INTERFACE + +} // extern "C" diff --git a/src/iceberg/error.h b/src/iceberg/error.h index 77414f900..066ef87f0 100644 --- a/src/iceberg/error.h +++ b/src/iceberg/error.h @@ -32,6 +32,8 @@ enum class ErrorKind { kAlreadyExists, kNoSuchTable, kCommitStateUnknown, + kInvalidSchema, + kInvalidArgument, }; /// \brief Error with a kind and a message. diff --git a/src/iceberg/schema_field.h b/src/iceberg/schema_field.h index e37c2d2d8..3fde248f7 100644 --- a/src/iceberg/schema_field.h +++ b/src/iceberg/schema_field.h @@ -64,7 +64,7 @@ class ICEBERG_EXPORT SchemaField : public iceberg::util::Formattable { /// \brief Get whether the field is optional. [[nodiscard]] bool optional() const; - [[nodiscard]] std::string ToString() const; + [[nodiscard]] std::string ToString() const override; friend bool operator==(const SchemaField& lhs, const SchemaField& rhs) { return lhs.Equals(rhs); diff --git a/src/iceberg/schema_internal.cc b/src/iceberg/schema_internal.cc new file mode 100644 index 000000000..7ca48d15d --- /dev/null +++ b/src/iceberg/schema_internal.cc @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/schema_internal.h" + +#include +#include +#include + +#include "iceberg/expected.h" +#include "iceberg/schema.h" + +namespace iceberg { + +namespace { + +constexpr const char* kArrowExtensionName = "ARROW:extension:name"; +constexpr const char* kArrowExtensionMetadata = "ARROW:extension:metadata"; + +// Convert an Iceberg type to Arrow schema. Return value is Nanoarrow error code. +ArrowErrorCode ToArrowSchema(const Type& type, bool optional, std::string_view name, + std::optional field_id, ArrowSchema* schema) { + ArrowBuffer metadata_buffer; + NANOARROW_RETURN_NOT_OK(ArrowMetadataBuilderInit(&metadata_buffer, nullptr)); + if (field_id.has_value()) { + NANOARROW_RETURN_NOT_OK(ArrowMetadataBuilderAppend( + &metadata_buffer, ArrowCharView(std::string(kFieldIdKey).c_str()), + ArrowCharView(std::to_string(field_id.value()).c_str()))); + } + + switch (type.type_id()) { + case TypeId::kStruct: { + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_STRUCT)); + + const auto& struct_type = static_cast(type); + const auto& fields = struct_type.fields(); + NANOARROW_RETURN_NOT_OK(ArrowSchemaAllocateChildren(schema, fields.size())); + + for (size_t i = 0; i < fields.size(); i++) { + const auto& field = fields[i]; + NANOARROW_RETURN_NOT_OK(ToArrowSchema(*field.type(), field.optional(), + field.name(), field.field_id(), + schema->children[i])); + } + } break; + case TypeId::kList: { + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_LIST)); + + const auto& list_type = static_cast(type); + const auto& elem_field = list_type.fields()[0]; + NANOARROW_RETURN_NOT_OK(ToArrowSchema(*elem_field.type(), elem_field.optional(), + elem_field.name(), elem_field.field_id(), + schema->children[0])); + } break; + case TypeId::kMap: { + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_MAP)); + + const auto& map_type = static_cast(type); + const auto& key_field = map_type.key(); + const auto& value_field = map_type.value(); + NANOARROW_RETURN_NOT_OK(ToArrowSchema(*key_field.type(), key_field.optional(), + key_field.name(), key_field.field_id(), + schema->children[0]->children[0])); + NANOARROW_RETURN_NOT_OK(ToArrowSchema(*value_field.type(), value_field.optional(), + value_field.name(), value_field.field_id(), + schema->children[0]->children[1])); + } break; + case TypeId::kBoolean: + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_BOOL)); + break; + case TypeId::kInt: + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_INT32)); + break; + case TypeId::kLong: + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_INT64)); + break; + case TypeId::kFloat: + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_FLOAT)); + break; + case TypeId::kDouble: + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_DOUBLE)); + break; + case TypeId::kDecimal: { + ArrowSchemaInit(schema); + const auto& decimal_type = static_cast(type); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeDecimal(schema, NANOARROW_TYPE_DECIMAL128, + decimal_type.precision(), + decimal_type.scale())); + } break; + case TypeId::kDate: + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_DATE32)); + break; + case TypeId::kTime: { + ArrowSchemaInit(schema); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeDateTime(schema, NANOARROW_TYPE_TIME64, + NANOARROW_TIME_UNIT_MICRO, + /*timezone=*/nullptr)); + } break; + case TypeId::kTimestamp: { + ArrowSchemaInit(schema); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeDateTime(schema, NANOARROW_TYPE_TIMESTAMP, + NANOARROW_TIME_UNIT_MICRO, + /*timezone=*/nullptr)); + } break; + case TypeId::kTimestampTz: { + ArrowSchemaInit(schema); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeDateTime( + schema, NANOARROW_TYPE_TIMESTAMP, NANOARROW_TIME_UNIT_MICRO, "UTC")); + } break; + case TypeId::kString: + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_STRING)); + break; + case TypeId::kBinary: + NANOARROW_RETURN_NOT_OK(ArrowSchemaInitFromType(schema, NANOARROW_TYPE_BINARY)); + break; + case TypeId::kFixed: { + ArrowSchemaInit(schema); + const auto& fixed_type = static_cast(type); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeFixedSize( + schema, NANOARROW_TYPE_FIXED_SIZE_BINARY, fixed_type.length())); + } break; + case TypeId::kUuid: { + ArrowSchemaInit(schema); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeFixedSize( + schema, NANOARROW_TYPE_FIXED_SIZE_BINARY, /*fixed_size=*/16)); + NANOARROW_RETURN_NOT_OK( + ArrowMetadataBuilderAppend(&metadata_buffer, ArrowCharView(kArrowExtensionName), + ArrowCharView("arrow.uuid"))); + } break; + } + + if (!name.empty()) { + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetName(schema, std::string(name).c_str())); + } + + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetMetadata( + schema, reinterpret_cast(metadata_buffer.data))); + ArrowBufferReset(&metadata_buffer); + + if (optional) { + schema->flags |= ARROW_FLAG_NULLABLE; + } else { + schema->flags &= ~ARROW_FLAG_NULLABLE; + } + + return NANOARROW_OK; +} + +} // namespace + +expected ToArrowSchema(const Schema& schema, ArrowSchema* out) { + if (out == nullptr) [[unlikely]] { + return unexpected{{.kind = ErrorKind::kInvalidArgument, + .message = "Output Arrow schema cannot be null"}}; + } + + if (ArrowErrorCode errorCode = ToArrowSchema(schema, /*optional=*/false, /*name=*/"", + /*field_id=*/std::nullopt, out); + errorCode != NANOARROW_OK) { + return unexpected{ + {.kind = ErrorKind::kInvalidSchema, + .message = std::format( + "Failed to convert Iceberg schema to Arrow schema, error code: {}", + errorCode)}}; + } + + return {}; +} + +expected, Error> FromArrowSchema(const ArrowSchema& schema, + int32_t schema_id) { + // TODO(wgtmac): Implement this + return unexpected{ + {.kind = ErrorKind::kInvalidSchema, .message = "Not implemented yet"}}; +} + +} // namespace iceberg diff --git a/src/iceberg/schema_internal.h b/src/iceberg/schema_internal.h new file mode 100644 index 000000000..164044755 --- /dev/null +++ b/src/iceberg/schema_internal.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +#include + +#include "iceberg/error.h" +#include "iceberg/expected.h" +#include "iceberg/type_fwd.h" + +namespace iceberg { + +// Apache Arrow C++ uses "PARQUET:field_id" to store field IDs for Parquet. +// Here we follow a similar convention for Iceberg but we might also add +// "PARQUET:field_id" in the future once we implement a Parquet writer. +constexpr std::string_view kFieldIdKey = "ICEBERG:field_id"; + +/// \brief Convert an Iceberg schema to an Arrow schema. +/// +/// \param[in] schema The Iceberg schema to convert. +/// \param[out] out The Arrow schema to convert to. +/// \return An error if the conversion fails. +expected ToArrowSchema(const Schema& schema, ArrowSchema* out); + +/// \brief Convert an Arrow schema to an Iceberg schema. +/// +/// \param[in] schema The Arrow schema to convert. +/// \param[in] schema_id The schema ID of the Iceberg schema. +/// \return The Iceberg schema or an error if the conversion fails. +expected, Error> FromArrowSchema(const ArrowSchema& schema, + int32_t schema_id); + +} // namespace iceberg diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5fbff2015..76006e0f9 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -30,18 +30,19 @@ add_test(NAME schema_test COMMAND schema_test) add_executable(expected_test) target_sources(expected_test PRIVATE expected_test.cc) -target_link_libraries(expected_test PRIVATE iceberg_static GTest::gtest_main) +target_link_libraries(expected_test PRIVATE iceberg_static GTest::gtest_main GTest::gmock) add_test(NAME expected_test COMMAND expected_test) if(ICEBERG_BUILD_BUNDLE) add_executable(avro_test) target_sources(avro_test PRIVATE avro_test.cc) - target_link_libraries(avro_test PRIVATE iceberg_bundle_static GTest::gtest_main) + target_link_libraries(avro_test PRIVATE iceberg_bundle_static GTest::gtest_main + GTest::gmock) add_test(NAME avro_test COMMAND avro_test) add_executable(arrow_test) target_sources(arrow_test PRIVATE arrow_test.cc) target_link_libraries(arrow_test PRIVATE iceberg_bundle_static Arrow::arrow_static - GTest::gtest_main) + GTest::gtest_main GTest::gmock) add_test(NAME arrow_test COMMAND arrow_test) endif() diff --git a/test/arrow_test.cc b/test/arrow_test.cc index 1d730fc49..b48df086e 100644 --- a/test/arrow_test.cc +++ b/test/arrow_test.cc @@ -17,12 +17,20 @@ * under the License. */ +#include + #include #include +#include #include +#include +#include #include #include "iceberg/arrow_c_data_internal.h" +#include "iceberg/schema.h" +#include "iceberg/schema_internal.h" +#include "matchers.h" namespace iceberg { @@ -59,4 +67,196 @@ TEST(ArrowCDataTest, CheckArrowSchemaAndArrayByNanoarrow) { EXPECT_EQ(name_column->GetScalar(2).ValueOrDie()->ToString(), "c"); } +struct ToArrowSchemaParam { + std::shared_ptr iceberg_type; + bool optional = true; + std::shared_ptr arrow_type; +}; + +class ToArrowSchemaTest : public ::testing::TestWithParam {}; + +TEST_P(ToArrowSchemaTest, PrimitiveType) { + constexpr std::string_view kFieldName = "foo"; + constexpr int32_t kFieldId = 1024; + const auto& param = GetParam(); + Schema schema( + /*schema_id=*/0, + {param.optional ? SchemaField::MakeOptional(kFieldId, std::string(kFieldName), + param.iceberg_type) + : SchemaField::MakeRequired(kFieldId, std::string(kFieldName), + param.iceberg_type)}); + ArrowSchema arrow_schema; + ASSERT_THAT(ToArrowSchema(schema, &arrow_schema), IsOk()); + + auto imported_schema = ::arrow::ImportSchema(&arrow_schema).ValueOrDie(); + ASSERT_EQ(imported_schema->num_fields(), 1); + + auto field = imported_schema->field(0); + ASSERT_EQ(field->name(), kFieldName); + ASSERT_EQ(field->nullable(), param.optional); + ASSERT_TRUE(field->type()->Equals(param.arrow_type)); + + auto metadata = field->metadata(); + ASSERT_TRUE(metadata->Contains(kFieldIdKey)); + ASSERT_EQ(metadata->Get(kFieldIdKey), std::to_string(kFieldId)); +} + +INSTANTIATE_TEST_SUITE_P( + SchemaConversion, ToArrowSchemaTest, + ::testing::Values( + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .optional = false, + .arrow_type = ::arrow::boolean()}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .optional = true, + .arrow_type = ::arrow::int32()}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::int64()}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::float32()}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::float64()}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(10, 2), + .arrow_type = ::arrow::decimal128(10, 2)}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::date32()}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::time64(arrow::TimeUnit::MICRO)}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::timestamp(arrow::TimeUnit::MICRO)}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::timestamp(arrow::TimeUnit::MICRO)}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::utf8()}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::binary()}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(), + .arrow_type = ::arrow::extension::uuid()}, + ToArrowSchemaParam{.iceberg_type = std::make_shared(20), + .arrow_type = ::arrow::fixed_size_binary(20)})); + +namespace { + +void CheckArrowField(const ::arrow::Field& field, ::arrow::Type::type type_id, + std::string_view name, bool nullable, int32_t field_id) { + ASSERT_EQ(field.name(), name); + ASSERT_EQ(field.nullable(), nullable); + ASSERT_EQ(field.type()->id(), type_id); + + auto metadata = field.metadata(); + ASSERT_TRUE(metadata != nullptr); + ASSERT_TRUE(metadata->Contains(kFieldIdKey)); + ASSERT_EQ(metadata->Get(kFieldIdKey), std::to_string(field_id)); +} + +} // namespace + +TEST(ToArrowSchemaTest, StructType) { + constexpr int32_t kStructFieldId = 1; + constexpr int32_t kIntFieldId = 2; + constexpr int32_t kStrFieldId = 3; + + constexpr std::string_view kStructFieldName = "struct_field"; + constexpr std::string_view kIntFieldName = "int_field"; + constexpr std::string_view kStrFieldName = "str_field"; + + auto struct_type = std::make_shared(std::vector{ + SchemaField::MakeRequired(kIntFieldId, std::string(kIntFieldName), + std::make_shared()), + SchemaField::MakeOptional(kStrFieldId, std::string(kStrFieldName), + std::make_shared())}); + Schema schema( + /*schema_id=*/0, {SchemaField::MakeRequired( + kStructFieldId, std::string(kStructFieldName), struct_type)}); + + ArrowSchema arrow_schema; + ASSERT_THAT(ToArrowSchema(schema, &arrow_schema), IsOk()); + + auto imported_schema = ::arrow::ImportSchema(&arrow_schema).ValueOrDie(); + ASSERT_EQ(imported_schema->num_fields(), 1); + + auto field = imported_schema->field(0); + ASSERT_NO_FATAL_FAILURE(CheckArrowField(*field, ::arrow::Type::STRUCT, kStructFieldName, + /*nullable=*/false, kStructFieldId)); + + auto struct_field = std::static_pointer_cast<::arrow::StructType>(field->type()); + ASSERT_EQ(struct_field->num_fields(), 2); + + ASSERT_NO_FATAL_FAILURE(CheckArrowField(*struct_field->field(0), ::arrow::Type::INT32, + kIntFieldName, /*nullable=*/false, + kIntFieldId)); + ASSERT_NO_FATAL_FAILURE(CheckArrowField(*struct_field->field(1), ::arrow::Type::STRING, + kStrFieldName, /*nullable=*/true, kStrFieldId)); +} + +TEST(ToArrowSchemaTest, ListType) { + constexpr std::string_view kListFieldName = "list_field"; + constexpr std::string_view kElemFieldName = "element"; + constexpr int32_t kListFieldId = 1; + constexpr int32_t kElemFieldId = 2; + + auto list_type = std::make_shared(SchemaField::MakeOptional( + kElemFieldId, std::string(kElemFieldName), std::make_shared())); + Schema schema( + /*schema_id=*/0, + {SchemaField::MakeRequired(kListFieldId, std::string(kListFieldName), list_type)}); + + ArrowSchema arrow_schema; + ASSERT_THAT(ToArrowSchema(schema, &arrow_schema), IsOk()); + + auto imported_schema = ::arrow::ImportSchema(&arrow_schema).ValueOrDie(); + ASSERT_EQ(imported_schema->num_fields(), 1); + + auto field = imported_schema->field(0); + ASSERT_NO_FATAL_FAILURE(CheckArrowField(*field, ::arrow::Type::LIST, kListFieldName, + /*nullable=*/false, kListFieldId)); + + auto list_field = std::static_pointer_cast<::arrow::ListType>(field->type()); + ASSERT_NO_FATAL_FAILURE(CheckArrowField(*list_field->value_field(), + ::arrow::Type::INT64, kElemFieldName, + /*nullable=*/true, kElemFieldId)); +} + +TEST(ToArrowSchemaTest, MapType) { + constexpr std::string_view kMapFieldName = "map_field"; + constexpr std::string_view kKeyFieldName = "key"; + constexpr std::string_view kValueFieldName = "value"; + + constexpr int32_t kFieldId = 1; + constexpr int32_t kKeyFieldId = 2; + constexpr int32_t kValueFieldId = 3; + + auto map_type = std::make_shared( + SchemaField::MakeRequired(kKeyFieldId, std::string(kKeyFieldName), + std::make_shared()), + SchemaField::MakeOptional(kValueFieldId, std::string(kValueFieldName), + std::make_shared())); + + Schema schema( + /*schema_id=*/0, + {SchemaField::MakeRequired(kFieldId, std::string(kMapFieldName), map_type)}); + + ArrowSchema arrow_schema; + ASSERT_THAT(ToArrowSchema(schema, &arrow_schema), IsOk()); + + auto imported_schema = ::arrow::ImportSchema(&arrow_schema).ValueOrDie(); + ASSERT_EQ(imported_schema->num_fields(), 1); + + auto field = imported_schema->field(0); + ASSERT_NO_FATAL_FAILURE(CheckArrowField(*field, ::arrow::Type::MAP, kMapFieldName, + /*nullable=*/false, kFieldId)); + + auto map_field = std::static_pointer_cast<::arrow::MapType>(field->type()); + + auto key_field = map_field->key_field(); + ASSERT_NO_FATAL_FAILURE(CheckArrowField(*key_field, ::arrow::Type::STRING, + kKeyFieldName, + /*nullable=*/false, kKeyFieldId)); + + auto value_field = map_field->item_field(); + ASSERT_NO_FATAL_FAILURE(CheckArrowField(*value_field, ::arrow::Type::INT32, + kValueFieldName, + /*nullable=*/true, kValueFieldId)); +} + } // namespace iceberg diff --git a/test/matchers.h b/test/matchers.h new file mode 100644 index 000000000..eca3b6fb2 --- /dev/null +++ b/test/matchers.h @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include + +#include "iceberg/error.h" + +/* + * \brief Define custom matchers for expected values + * + * Example usage of these matchers: + * + * Basic assertions: + * + * // Check that a result is ok + * EXPECT_THAT(result, IsOk()); + * + * // Check that a result is an error of a specific kind + * EXPECT_THAT(result, IsError(ErrorKind::kNoSuchTable)); + * + * // Check that an error message contains a specific substring + * EXPECT_THAT(result, HasErrorMessage("table not found")); + * + * Value inspection: + * + * // Check that a result has a value that equals 42 + * EXPECT_THAT(result, HasValue(42)); + * + * // Check that a result has a value that satisfies a complex condition + * EXPECT_THAT(result, HasValue(AllOf(Gt(10), Lt(50)))); + * + * Combined assertions: + * + * // Check that the result value has a specific property + * EXPECT_THAT(result, ResultIs(Property(&MyType::name, "example"))); + * + * // Check that the error matches specific criteria + * EXPECT_THAT(result, ErrorIs(AllOf( + * Property(&Error::kind, ErrorKind::kNoSuchTable), + * Property(&Error::message, HasSubstr("table not found")) + * ))); + */ + +namespace iceberg { + +// IsOk matcher that checks if the expected value has a value (not an error) +MATCHER(IsOk, "is an Ok result") { + if (arg.has_value()) { + return true; + } + *result_listener << "which contains error: " << arg.error().message; + return false; +} + +// IsError matcher that checks if the expected value contains an error +MATCHER_P(IsError, kind, "is an Error with the specified kind") { + if (!arg.has_value()) { + if (arg.error().kind == kind) { + return true; + } + *result_listener << "which contains error kind " << static_cast(arg.error().kind) + << " but expected " << static_cast(kind) + << ", message: " << arg.error().message; + return false; + } + *result_listener << "which is not an error but a value"; + return false; +} + +// HasErrorMessage matcher that checks if the expected value contains an error with a +// specific message or substring +MATCHER_P(HasErrorMessage, message_substr, + "is an Error with message containing the substring") { + if (!arg.has_value()) { + if (arg.error().message.find(message_substr) != std::string::npos) { + return true; + } + *result_listener << "which contains error with message '" << arg.error().message + << "' that doesn't contain '" << message_substr << "'"; + return false; + } + *result_listener << "which is not an error but a value"; + return false; +} + +// HasValue matcher that checks if the expected value contains a value that matches the +// given matcher +template +class HasValueMatcher { + public: + explicit HasValueMatcher(MatcherT matcher) : matcher_(std::move(matcher)) {} + + template + bool MatchAndExplain(const T& value, + ::testing::MatchResultListener* result_listener) const { + if (!value.has_value()) { + *result_listener << "which is an error: " << value.error().message; + return false; + } + + return ::testing::MatcherCast(matcher_) + .MatchAndExplain(*value, result_listener); + } + + void DescribeTo(std::ostream* os) const { + *os << "has a value that "; + matcher_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "does not have a value that "; + matcher_.DescribeTo(os); + } + + private: + MatcherT matcher_; +}; + +// Factory function for HasValueMatcher +template +auto HasValue(MatcherT&& matcher) { + return ::testing::MakePolymorphicMatcher( + HasValueMatcher>(std::forward(matcher))); +} + +// Overload for the common case where we just want to check for presence of any value +inline auto HasValue() { return IsOk(); } + +// Matcher that checks an expected value against an expected value and a matcher +template +class ResultMatcher { + public: + explicit ResultMatcher(bool should_have_value, MatcherT matcher) + : should_have_value_(should_have_value), matcher_(std::move(matcher)) {} + + template + bool MatchAndExplain(const T& value, + ::testing::MatchResultListener* result_listener) const { + if (value.has_value() != should_have_value_) { + if (should_have_value_) { + *result_listener << "which is an error: " << value.error().message; + } else { + *result_listener << "which is a value, not an error"; + } + return false; + } + + if (should_have_value_) { + return ::testing::MatcherCast(matcher_) + .MatchAndExplain(*value, result_listener); + } else { + return ::testing::MatcherCast(matcher_) + .MatchAndExplain(value.error(), result_listener); + } + } + + void DescribeTo(std::ostream* os) const { + if (should_have_value_) { + *os << "has a value that "; + } else { + *os << "has an error that "; + } + matcher_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const { + if (should_have_value_) { + *os << "does not have a value that "; + } else { + *os << "does not have an error that "; + } + matcher_.DescribeTo(os); + } + + private: + bool should_have_value_; + MatcherT matcher_; +}; + +// Factory function for ResultMatcher for values +template +auto ResultIs(MatcherT&& matcher) { + return ::testing::MakePolymorphicMatcher( + ResultMatcher>(true, std::forward(matcher))); +} + +// Factory function for ResultMatcher for errors +template +auto ErrorIs(MatcherT&& matcher) { + return ::testing::MakePolymorphicMatcher( + ResultMatcher>(false, std::forward(matcher))); +} + +} // namespace iceberg