Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions python/pyarrow/src/arrow/python/python_to_arrow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,14 @@ class PyConverter : public Converter<PyObject*, PyConversionOptions> {
}
};

// Helper function to unwrap extension scalar to its storage scalar
inline const Scalar& GetStorageScalar(const Scalar& scalar) {
if (scalar.type->id() == Type::EXTENSION) {
return *checked_cast<const ExtensionScalar&>(scalar).value;
}
return scalar;
}

template <typename T, typename Enable = void>
class PyPrimitiveConverter;

Expand Down Expand Up @@ -657,7 +665,8 @@ class PyPrimitiveConverter<
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
Expand All @@ -678,7 +687,8 @@ class PyPrimitiveConverter<
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
Expand All @@ -704,7 +714,8 @@ class PyPrimitiveConverter<T, enable_if_t<std::is_same<T, FixedSizeBinaryType>::
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
Expand Down Expand Up @@ -741,7 +752,8 @@ class PyPrimitiveConverter<
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
Expand Down Expand Up @@ -785,7 +797,7 @@ class PyDictionaryConverter<U, enable_if_has_c_type<U>>
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
return this->value_builder_->AppendScalar(GetStorageScalar(*scalar), 1);
} else {
ARROW_ASSIGN_OR_RAISE(auto converted,
PyValue::Convert(this->value_type_, this->options_, value));
Expand All @@ -804,7 +816,7 @@ class PyDictionaryConverter<U, enable_if_has_string_view<U>>
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
return this->value_builder_->AppendScalar(GetStorageScalar(*scalar), 1);
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->value_type_, this->options_, value, view_));
Expand Down Expand Up @@ -977,7 +989,7 @@ class PyStructConverter : public StructConverter<PyConverter, PyConverterTrait>
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->struct_builder_->AppendScalar(*scalar);
return this->struct_builder_->AppendScalar(GetStorageScalar(*scalar));
}
switch (input_kind_) {
case InputKind::DICT:
Expand Down
63 changes: 63 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,69 @@ def test_uuid_extension():
assert isinstance(array[0], pa.UuidScalar)


def test_array_from_extension_scalars():
# Test unwrap to various storage types and different converters
import datetime

builtin_cases = [
# fixed_size_binary[16] storage
(pa.uuid(), [b"0123456789abcdef"], [
UUID('30313233-3435-3637-3839-616263646566')]),
# int8 storage
(pa.bool8(), [0, 1], [0, 1]),
# string storage
(pa.json_(pa.string()), ['{"a":1}', '{"b":2}'], ['{"a":1}', '{"b":2}']),
# binary storage
(pa.opaque(pa.binary(), "t", "v"), [b"x", b"y"], [b"x", b"y"]),
]
for ext_type, values, expected in builtin_cases:
scalars = [pa.scalar(v, type=ext_type) for v in values]
result = pa.array(scalars, type=ext_type)
assert result.type == ext_type
# TODO: make `expected` pyarrow array so `to_pylist` isn't used, check GH-48241
assert result.to_pylist() == expected

# Custom extension types requiring registration
custom_cases = [
# int8 storage
(TinyIntType(), [1, 2], [1, 2]),
# int64 storage
(IntegerType(), [100, 200], [100, 200]),
# string storage
(LabelType(), ["a", "b"], ["a", "b"]),
# struct storage
(MyStructType(), [{"left": 1, "right": 2}], [{"left": 1, "right": 2}]),
# timestamp storage
(AnnotatedType(pa.timestamp("us"), "ts"),
[datetime.datetime(2023, 1, 1)], [datetime.datetime(2023, 1, 1)]),
# duration storage
(AnnotatedType(pa.duration("s"), "dur"),
[datetime.timedelta(seconds=100)], [datetime.timedelta(seconds=100)]),
# date storage
(AnnotatedType(pa.date32(), "date"),
[datetime.date(2023, 1, 1)], [datetime.date(2023, 1, 1)]),
# float64 storage
(AnnotatedType(pa.float64(), "f"), [1.5, 2.5], [1.5, 2.5]),
# boolean storage
(AnnotatedType(pa.bool_(), "b"), [True, False], [True, False]),
# binary storage
(AnnotatedType(pa.binary(), "bin"), [b"x", b"y"], [b"x", b"y"]),
]
for ext_type, values, expected in custom_cases:
with registered_extension_type(ext_type):
scalars = [pa.scalar(v, type=ext_type) for v in values]
result = pa.array(scalars, type=ext_type)
assert result.type == ext_type
# TODO: make `expected` pyarrow array so `to_pylist` isn't used
assert result.to_pylist() == expected

uuid_type = pa.uuid()
scalars = [pa.scalar(b"0123456789abcdef", type=uuid_type),
pa.scalar(None, type=uuid_type)]
result = pa.array(scalars, type=uuid_type)
assert result[0].is_valid and not result[1].is_valid


def test_tensor_type():
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3])
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
Expand Down
Loading