diff --git a/native/spark-expr/src/json_funcs/to_json.rs b/native/spark-expr/src/json_funcs/to_json.rs index 46b87a40c7..74e772936c 100644 --- a/native/spark-expr/src/json_funcs/to_json.rs +++ b/native/spark-expr/src/json_funcs/to_json.rs @@ -124,12 +124,19 @@ fn array_to_json_string(arr: &Arc, timezone: &str) -> Result() { struct_to_json(struct_array, timezone) } else { - spark_cast( + let array = spark_cast( ColumnarValue::Array(Arc::clone(arr)), &DataType::Utf8, &SparkCastOptions::new(EvalMode::Legacy, timezone, false), )? - .into_array(arr.len()) + .into_array(arr.len())?; + + let string_array = array + .as_any() + .downcast_ref::() + .expect("Utf8 array"); + + Ok(normalize_special_floats(string_array)) } } @@ -181,6 +188,23 @@ fn escape_string(input: &str) -> String { escaped_string } +fn normalize_special_floats(arr: &StringArray) -> ArrayRef { + let mut builder = StringBuilder::with_capacity(arr.len(), arr.len() * 8); + + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + match arr.value(i) { + "Infinity" | "-Infinity" | "NaN" => builder.append_null(), + v => builder.append_value(v), + } + } + } + + Arc::new(builder.finish()) +} + fn struct_to_json(array: &StructArray, timezone: &str) -> Result { // get field names and escape any quotes let field_names: Vec = array @@ -331,6 +355,34 @@ mod test { Ok(()) } + #[test] + fn test_to_json_infinity() -> Result<()> { + use arrow::array::{Float64Array, StructArray}; + use arrow::datatypes::{DataType, Field}; + + let values: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(f64::NAN), + Some(1.5), + ])); + + let struct_array = StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Float64, true)), + values, + )]); + + let json = struct_to_json(&struct_array, "UTC")?; + let json = json.as_any().downcast_ref::().unwrap(); + + assert_eq!(r#"{}"#, json.value(0)); + assert_eq!(r#"{}"#, json.value(1)); + assert_eq!(r#"{}"#, json.value(2)); + assert_eq!(r#"{"a":1.5}"#, json.value(3)); + + Ok(()) + } + fn create_ints() -> Arc> { Arc::new(Int32Array::from(vec![ Some(123),