diff --git a/python/pyarrow/src/arrow/python/python_to_arrow.cc b/python/pyarrow/src/arrow/python/python_to_arrow.cc
index 139eb1d7f4f..19505778b3e 100644
--- a/python/pyarrow/src/arrow/python/python_to_arrow.cc
+++ b/python/pyarrow/src/arrow/python/python_to_arrow.cc
@@ -578,6 +578,17 @@ class PyConverter : public Converter<PyObject*, PyConversionOptions> {
   }
 };
 
+// Returns the scalar to hand to a builder's AppendScalar(): for an extension
+// scalar (type id Type::EXTENSION) this is the scalar held in its `value`
+// member; any other scalar is returned unchanged. The result is a reference,
+// not a copy, so it must not be used after `scalar` has been destroyed.
+inline const Scalar& GetStorageScalar(const Scalar& scalar) {
+  if (scalar.type->id() == Type::EXTENSION) {
+    return *checked_cast<const ExtensionScalar&>(scalar).value;
+  }
+  return scalar;
+}
+
 template <typename T, typename Enable = void>
 class PyPrimitiveConverter;
 
@@ -657,7 +668,8 @@ class PyPrimitiveConverter<
     } else if (arrow::py::is_scalar(value)) {
       ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
                             arrow::py::unwrap_scalar(value));
-      ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
+      ARROW_RETURN_NOT_OK(
+          this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
     } else {
       ARROW_ASSIGN_OR_RAISE(
           auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
@@ -678,7 +690,8 @@ class PyPrimitiveConverter<
     } else if (arrow::py::is_scalar(value)) {
       ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
                             arrow::py::unwrap_scalar(value));
-      ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
+      ARROW_RETURN_NOT_OK(
+          this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
     } else {
       ARROW_ASSIGN_OR_RAISE(
           auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
@@ -704,7 +717,8 @@ class PyPrimitiveConverter<T, enable_if_t<std::is_same<T, FixedSizeBinaryType>::
     } else if (arrow::py::is_scalar(value)) {
       ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
                             arrow::py::unwrap_scalar(value));
-      ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
+      ARROW_RETURN_NOT_OK(
+          this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
     } else {
       ARROW_RETURN_NOT_OK(
           PyValue::Convert(this->primitive_type_, this->options_, value, view_));
@@ -741,7 +755,8 @@ class PyPrimitiveConverter<
     } else if (arrow::py::is_scalar(value)) {
       ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
                             arrow::py::unwrap_scalar(value));
-      ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
+      ARROW_RETURN_NOT_OK(
+          this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
     } else {
       ARROW_RETURN_NOT_OK(
           PyValue::Convert(this->primitive_type_, this->options_, value, view_));
@@ -785,7 +800,7 @@ class PyDictionaryConverter<U, enable_if_has_c_type<U>>
     } else if (arrow::py::is_scalar(value)) {
       ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
                             arrow::py::unwrap_scalar(value));
-      return this->value_builder_->AppendScalar(*scalar, 1);
+      return this->value_builder_->AppendScalar(GetStorageScalar(*scalar), 1);
     } else {
       ARROW_ASSIGN_OR_RAISE(auto converted,
                             PyValue::Convert(this->value_type_, this->options_, value));
@@ -804,7 +819,7 @@ class PyDictionaryConverter<U, enable_if_has_string_view<U>>
     } else if (arrow::py::is_scalar(value)) {
       ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
                             arrow::py::unwrap_scalar(value));
-      return this->value_builder_->AppendScalar(*scalar, 1);
+      return this->value_builder_->AppendScalar(GetStorageScalar(*scalar), 1);
     } else {
       ARROW_RETURN_NOT_OK(
           PyValue::Convert(this->value_type_, this->options_, value, view_));
@@ -977,7 +992,7 @@ class PyStructConverter : public StructConverter<PyConverter, PyConverterTrait>
     } else if (arrow::py::is_scalar(value)) {
       ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
                             arrow::py::unwrap_scalar(value));
-      return this->struct_builder_->AppendScalar(*scalar);
+      return this->struct_builder_->AppendScalar(GetStorageScalar(*scalar));
     }
     switch (input_kind_) {
       case InputKind::DICT:
diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
index ebac37e862b..c80d5f915e1 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -1399,6 +1399,69 @@ def test_uuid_extension():
     assert isinstance(array[0], pa.UuidScalar)
 
 
+def test_array_from_extension_scalars():
+    # Test unwrap to various storage types and different converters
+    import datetime
+
+    builtin_cases = [
+        # fixed_size_binary[16] storage
+        (pa.uuid(), [b"0123456789abcdef"], [
+            UUID('30313233-3435-3637-3839-616263646566')]),
+        # int8 storage
+        (pa.bool8(), [0, 1], [0, 1]),
+        # string storage
+        (pa.json_(pa.string()), ['{"a":1}', '{"b":2}'], ['{"a":1}', '{"b":2}']),
+        # binary storage
+        (pa.opaque(pa.binary(), "t", "v"), [b"x", b"y"], [b"x", b"y"]),
+    ]
+    for ext_type, values, expected in builtin_cases:
+        scalars = [pa.scalar(v, type=ext_type) for v in values]
+        result = pa.array(scalars, type=ext_type)
+        assert result.type == ext_type
+        # TODO: make `expected` pyarrow array so `to_pylist` isn't used, check GH-48241
+        assert result.to_pylist() == expected
+
+    # Custom extension types requiring registration
+    custom_cases = [
+        # int8 storage
+        (TinyIntType(), [1, 2], [1, 2]),
+        # int64 storage
+        (IntegerType(), [100, 200], [100, 200]),
+        # string storage
+        (LabelType(), ["a", "b"], ["a", "b"]),
+        # struct storage
+        (MyStructType(), [{"left": 1, "right": 2}], [{"left": 1, "right": 2}]),
+        # timestamp storage
+        (AnnotatedType(pa.timestamp("us"), "ts"),
+         [datetime.datetime(2023, 1, 1)], [datetime.datetime(2023, 1, 1)]),
+        # duration storage
+        (AnnotatedType(pa.duration("s"), "dur"),
+         [datetime.timedelta(seconds=100)], [datetime.timedelta(seconds=100)]),
+        # date storage
+        (AnnotatedType(pa.date32(), "date"),
+         [datetime.date(2023, 1, 1)], [datetime.date(2023, 1, 1)]),
+        # float64 storage
+        (AnnotatedType(pa.float64(), "f"), [1.5, 2.5], [1.5, 2.5]),
+        # boolean storage
+        (AnnotatedType(pa.bool_(), "b"), [True, False], [True, False]),
+        # binary storage
+        (AnnotatedType(pa.binary(), "bin"), [b"x", b"y"], [b"x", b"y"]),
+    ]
+    for ext_type, values, expected in custom_cases:
+        with registered_extension_type(ext_type):
+            scalars = [pa.scalar(v, type=ext_type) for v in values]
+            result = pa.array(scalars, type=ext_type)
+            assert result.type == ext_type
+            # TODO: make `expected` pyarrow array so `to_pylist` isn't used
+            assert result.to_pylist() == expected
+
+    uuid_type = pa.uuid()
+    scalars = [pa.scalar(b"0123456789abcdef", type=uuid_type),
+               pa.scalar(None, type=uuid_type)]
+    result = pa.array(scalars, type=uuid_type)
+    assert result[0].is_valid and not result[1].is_valid
+
+
 def test_tensor_type():
     tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3])
     assert tensor_type.extension_name == "arrow.fixed_shape_tensor"