Skip to content

Commit c25c5a7

Browse files
authored
implement cast_to_variant kernel to cast native types to VariantArray (#8044)
# Which issue does this PR close? We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. - Closes #8043 # Rationale for this change As @Samyak2 suggested on https://github.com/apache/arrow-rs/pull/8021/files#r2249926579, having the ability to convert *FROM* a typed value to a VariantArray will be important For example, in SQL it could be used to cast columns to variant like in `some_column::variant` # What changes are included in this PR? 1. Add `cast_to_variant` kernel to cast native types to `VariantArray` 2. Tests # Are these changes tested? yes # Are there any user-facing changes? New kernel
1 parent 5dd3463 commit c25c5a7

File tree

3 files changed

+395
-0
lines changed

3 files changed

+395
-0
lines changed
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::{VariantArray, VariantArrayBuilder};
19+
use arrow::array::{Array, AsArray};
20+
use arrow::datatypes::{
21+
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
22+
UInt64Type, UInt8Type,
23+
};
24+
use arrow_schema::{ArrowError, DataType};
25+
use parquet_variant::Variant;
26+
27+
/// Convert the input array of a specific primitive type to a `VariantArray`
28+
/// row by row
29+
macro_rules! primitive_conversion {
30+
($t:ty, $input:expr, $builder:expr) => {{
31+
let array = $input.as_primitive::<$t>();
32+
for i in 0..array.len() {
33+
if array.is_null(i) {
34+
$builder.append_null();
35+
continue;
36+
}
37+
$builder.append_variant(Variant::from(array.value(i)));
38+
}
39+
}};
40+
}
41+
42+
/// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you
43+
/// need to convert a specific data type
44+
///
45+
/// # Arguments
46+
/// * `input` - A reference to the input [`Array`] to cast
47+
///
48+
/// # Notes
49+
/// If the input array element is null, the corresponding element in the
50+
/// output `VariantArray` will also be null (not `Variant::Null`).
51+
///
52+
/// # Example
53+
/// ```
54+
/// # use arrow::array::{Array, ArrayRef, Int64Array};
55+
/// # use parquet_variant::Variant;
56+
/// # use parquet_variant_compute::cast_to_variant::cast_to_variant;
57+
/// // input is an Int64Array, which will be cast to a VariantArray
58+
/// let input = Int64Array::from(vec![Some(1), None, Some(3)]);
59+
/// let result = cast_to_variant(&input).unwrap();
60+
/// assert_eq!(result.len(), 3);
61+
/// assert_eq!(result.value(0), Variant::Int64(1));
62+
/// assert!(result.is_null(1)); // note null, not Variant::Null
63+
/// assert_eq!(result.value(2), Variant::Int64(3));
64+
/// ```
65+
pub fn cast_to_variant(input: &dyn Array) -> Result<VariantArray, ArrowError> {
66+
let mut builder = VariantArrayBuilder::new(input.len());
67+
68+
let input_type = input.data_type();
69+
// todo: handle other types like Boolean, Strings, Date, Timestamp, etc.
70+
match input_type {
71+
DataType::Int8 => {
72+
primitive_conversion!(Int8Type, input, builder);
73+
}
74+
DataType::Int16 => {
75+
primitive_conversion!(Int16Type, input, builder);
76+
}
77+
DataType::Int32 => {
78+
primitive_conversion!(Int32Type, input, builder);
79+
}
80+
DataType::Int64 => {
81+
primitive_conversion!(Int64Type, input, builder);
82+
}
83+
DataType::UInt8 => {
84+
primitive_conversion!(UInt8Type, input, builder);
85+
}
86+
DataType::UInt16 => {
87+
primitive_conversion!(UInt16Type, input, builder);
88+
}
89+
DataType::UInt32 => {
90+
primitive_conversion!(UInt32Type, input, builder);
91+
}
92+
DataType::UInt64 => {
93+
primitive_conversion!(UInt64Type, input, builder);
94+
}
95+
DataType::Float32 => {
96+
primitive_conversion!(Float32Type, input, builder);
97+
}
98+
DataType::Float64 => {
99+
primitive_conversion!(Float64Type, input, builder);
100+
}
101+
dt => {
102+
return Err(ArrowError::CastError(format!(
103+
"Unsupported data type for casting to Variant: {dt:?}",
104+
)));
105+
}
106+
};
107+
Ok(builder.build())
108+
}
109+
110+
// TODO do we need a cast_with_options to allow specifying conversion behavior,
111+
// e.g. how to handle overflows, whether to convert to Variant::Null or return
112+
// an error, etc. ?
113+
114+
#[cfg(test)]
115+
mod tests {
116+
use super::*;
117+
use arrow::array::{
118+
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
119+
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
120+
};
121+
use parquet_variant::{Variant, VariantDecimal16};
122+
use std::sync::Arc;
123+
124+
#[test]
125+
fn test_cast_to_variant_int8() {
126+
run_test(
127+
Arc::new(Int8Array::from(vec![
128+
Some(i8::MIN),
129+
None,
130+
Some(-1),
131+
Some(1),
132+
Some(i8::MAX),
133+
])),
134+
vec![
135+
Some(Variant::Int8(i8::MIN)),
136+
None,
137+
Some(Variant::Int8(-1)),
138+
Some(Variant::Int8(1)),
139+
Some(Variant::Int8(i8::MAX)),
140+
],
141+
)
142+
}
143+
144+
#[test]
145+
fn test_cast_to_variant_int16() {
146+
run_test(
147+
Arc::new(Int16Array::from(vec![
148+
Some(i16::MIN),
149+
None,
150+
Some(-1),
151+
Some(1),
152+
Some(i16::MAX),
153+
])),
154+
vec![
155+
Some(Variant::Int16(i16::MIN)),
156+
None,
157+
Some(Variant::Int16(-1)),
158+
Some(Variant::Int16(1)),
159+
Some(Variant::Int16(i16::MAX)),
160+
],
161+
)
162+
}
163+
164+
#[test]
165+
fn test_cast_to_variant_int32() {
166+
run_test(
167+
Arc::new(Int32Array::from(vec![
168+
Some(i32::MIN),
169+
None,
170+
Some(-1),
171+
Some(1),
172+
Some(i32::MAX),
173+
])),
174+
vec![
175+
Some(Variant::Int32(i32::MIN)),
176+
None,
177+
Some(Variant::Int32(-1)),
178+
Some(Variant::Int32(1)),
179+
Some(Variant::Int32(i32::MAX)),
180+
],
181+
)
182+
}
183+
184+
#[test]
185+
fn test_cast_to_variant_int64() {
186+
run_test(
187+
Arc::new(Int64Array::from(vec![
188+
Some(i64::MIN),
189+
None,
190+
Some(-1),
191+
Some(1),
192+
Some(i64::MAX),
193+
])),
194+
vec![
195+
Some(Variant::Int64(i64::MIN)),
196+
None,
197+
Some(Variant::Int64(-1)),
198+
Some(Variant::Int64(1)),
199+
Some(Variant::Int64(i64::MAX)),
200+
],
201+
)
202+
}
203+
204+
#[test]
205+
fn test_cast_to_variant_uint8() {
206+
run_test(
207+
Arc::new(UInt8Array::from(vec![
208+
Some(0),
209+
None,
210+
Some(1),
211+
Some(127),
212+
Some(u8::MAX),
213+
])),
214+
vec![
215+
Some(Variant::Int8(0)),
216+
None,
217+
Some(Variant::Int8(1)),
218+
Some(Variant::Int8(127)),
219+
Some(Variant::Int16(255)), // u8::MAX cannot fit in Int8
220+
],
221+
)
222+
}
223+
224+
#[test]
225+
fn test_cast_to_variant_uint16() {
226+
run_test(
227+
Arc::new(UInt16Array::from(vec![
228+
Some(0),
229+
None,
230+
Some(1),
231+
Some(32767),
232+
Some(u16::MAX),
233+
])),
234+
vec![
235+
Some(Variant::Int16(0)),
236+
None,
237+
Some(Variant::Int16(1)),
238+
Some(Variant::Int16(32767)),
239+
Some(Variant::Int32(65535)), // u16::MAX cannot fit in Int16
240+
],
241+
)
242+
}
243+
244+
#[test]
245+
fn test_cast_to_variant_uint32() {
246+
run_test(
247+
Arc::new(UInt32Array::from(vec![
248+
Some(0),
249+
None,
250+
Some(1),
251+
Some(2147483647),
252+
Some(u32::MAX),
253+
])),
254+
vec![
255+
Some(Variant::Int32(0)),
256+
None,
257+
Some(Variant::Int32(1)),
258+
Some(Variant::Int32(2147483647)),
259+
Some(Variant::Int64(4294967295)), // u32::MAX cannot fit in Int32
260+
],
261+
)
262+
}
263+
264+
#[test]
265+
fn test_cast_to_variant_uint64() {
266+
run_test(
267+
Arc::new(UInt64Array::from(vec![
268+
Some(0),
269+
None,
270+
Some(1),
271+
Some(9223372036854775807),
272+
Some(u64::MAX),
273+
])),
274+
vec![
275+
Some(Variant::Int64(0)),
276+
None,
277+
Some(Variant::Int64(1)),
278+
Some(Variant::Int64(9223372036854775807)),
279+
Some(Variant::Decimal16(
280+
// u64::MAX cannot fit in Int64
281+
VariantDecimal16::try_from(18446744073709551615).unwrap(),
282+
)),
283+
],
284+
)
285+
}
286+
287+
#[test]
288+
fn test_cast_to_variant_float32() {
289+
run_test(
290+
Arc::new(Float32Array::from(vec![
291+
Some(f32::MIN),
292+
None,
293+
Some(-1.5),
294+
Some(0.0),
295+
Some(1.5),
296+
Some(f32::MAX),
297+
])),
298+
vec![
299+
Some(Variant::Float(f32::MIN)),
300+
None,
301+
Some(Variant::Float(-1.5)),
302+
Some(Variant::Float(0.0)),
303+
Some(Variant::Float(1.5)),
304+
Some(Variant::Float(f32::MAX)),
305+
],
306+
)
307+
}
308+
309+
#[test]
310+
fn test_cast_to_variant_float64() {
311+
run_test(
312+
Arc::new(Float64Array::from(vec![
313+
Some(f64::MIN),
314+
None,
315+
Some(-1.5),
316+
Some(0.0),
317+
Some(1.5),
318+
Some(f64::MAX),
319+
])),
320+
vec![
321+
Some(Variant::Double(f64::MIN)),
322+
None,
323+
Some(Variant::Double(-1.5)),
324+
Some(Variant::Double(0.0)),
325+
Some(Variant::Double(1.5)),
326+
Some(Variant::Double(f64::MAX)),
327+
],
328+
)
329+
}
330+
331+
/// Converts the given `Array` to a `VariantArray` and tests the conversion
332+
/// against the expected values. It also tests the handling of nulls by
333+
/// setting one element to null and verifying the output.
334+
fn run_test(values: ArrayRef, expected: Vec<Option<Variant>>) {
335+
// test without nulls
336+
let variant_array = cast_to_variant(&values).unwrap();
337+
assert_eq!(variant_array.len(), expected.len());
338+
for (i, expected_value) in expected.iter().enumerate() {
339+
match expected_value {
340+
Some(value) => {
341+
assert!(!variant_array.is_null(i), "Expected non-null at index {i}");
342+
assert_eq!(variant_array.value(i), *value, "mismatch at index {i}");
343+
}
344+
None => {
345+
assert!(variant_array.is_null(i), "Expected null at index {i}");
346+
}
347+
}
348+
}
349+
}
350+
}

parquet-variant-compute/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
pub mod cast_to_variant;
1819
mod from_json;
1920
mod to_json;
2021
mod variant_array;

0 commit comments

Comments
 (0)