@@ -45,6 +45,7 @@ public CometListVector(

@Override
public ColumnarArray getArray(int i) {
if (isNullAt(i)) return null;
int start = listVector.getOffsetBuffer().getInt(i * ListVector.OFFSET_WIDTH);
int end = listVector.getOffsetBuffer().getInt((i + 1) * ListVector.OFFSET_WIDTH);

@@ -65,6 +65,7 @@ public CometMapVector(

@Override
public ColumnarMap getMap(int i) {
if (isNullAt(i)) return null;
int start = mapVector.getOffsetBuffer().getInt(i * MapVector.OFFSET_WIDTH);
int end = mapVector.getOffsetBuffer().getInt((i + 1) * MapVector.OFFSET_WIDTH);

@@ -123,6 +123,7 @@ public double getDouble(int rowId) {

@Override
public UTF8String getUTF8String(int rowId) {
if (isNullAt(rowId)) return null;
if (!isBaseFixedWidthVector) {
BaseVariableWidthVector varWidthVector = (BaseVariableWidthVector) valueVector;
long offsetBufferAddress = varWidthVector.getOffsetBuffer().memoryAddress();
@@ -147,6 +148,7 @@ public UTF8String getUTF8String(int rowId) {

@Override
public byte[] getBinary(int rowId) {
if (isNullAt(rowId)) return null;
int offset;
int length;
if (valueVector instanceof BaseVariableWidthVector) {
@@ -85,6 +85,7 @@ public boolean isFixedLength() {

@Override
public Decimal getDecimal(int i, int precision, int scale) {
if (isNullAt(i)) return null;
if (!useDecimal128 && precision <= Decimal.MAX_INT_DIGITS() && type instanceof IntegerType) {
return createDecimal(getInt(i), precision, scale);
} else if (precision <= Decimal.MAX_LONG_DIGITS()) {
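All four Java hunks above add the same guard: check the validity bitmap via isNullAt before decoding offsets or values, and return null for a null slot instead of materializing a value from whatever happens to be in the buffers. A minimal Rust sketch of the same idea over an Arrow ListArray, assuming only the arrow crate; the helper name get_list is made up and this is not Comet code:

```rust
use arrow::array::{Array, Int32Array, ListArray};
use arrow::datatypes::Int32Type;

/// Hypothetical accessor: returns the i-th list as owned values, or None when
/// the top-level slot is null. The null check comes first, mirroring the
/// `if (isNullAt(i)) return null;` guards added in the Java hunks above.
fn get_list(list: &ListArray, i: usize) -> Option<Vec<Option<i32>>> {
    if list.is_null(i) {
        return None; // never decode offsets/values for a null slot
    }
    let values = list.value(i); // ArrayRef holding the child slice for row i
    let ints = values.as_any().downcast_ref::<Int32Array>().unwrap();
    Some(
        (0..ints.len())
            .map(|j| if ints.is_null(j) { None } else { Some(ints.value(j)) })
            .collect(),
    )
}

fn main() {
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
        Some(vec![Some(1), None, Some(3)]),
        None, // a null row: the accessor should not fabricate an array here
    ]);
    assert_eq!(get_list(&list, 0), Some(vec![Some(1), None, Some(3)]));
    assert_eq!(get_list(&list, 1), None);
}
```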
184 changes: 113 additions & 71 deletions native/spark-expr/src/array_funcs/array_insert.rs
@@ -16,11 +16,10 @@
// under the License.

use arrow::array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, Schema};
use arrow::{
array::{as_primitive_array, Capacities, MutableArrayData},
buffer::{NullBuffer, OffsetBuffer},
datatypes::ArrowNativeType,
record_batch::RecordBatch,
};
use datafusion::common::{
@@ -198,114 +197,124 @@ fn array_insert<O: OffsetSizeTrait>(
pos_array: &ArrayRef,
legacy_mode: bool,
) -> DataFusionResult<ColumnarValue> {
// The code is based on the implementation of the array_append from the Apache DataFusion
// https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513
//
// This code is also based on the implementation of the array_insert from the Apache Spark
// https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713
// Implementation aligned with Arrow's half-open offset ranges and Spark semantics.
Contributor Author: The version fixed by ChatGPT :)


let values = list_array.values();
let offsets = list_array.offsets();
let values_data = values.to_data();
let item_data = items_array.to_data();

// Estimate capacity (original values + inserted items upper bound)
let new_capacity = Capacities::Array(values_data.len() + item_data.len());

let mut mutable_values =
MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);

let mut new_offsets = vec![O::usize_as(0)];
let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
// New offsets and top-level list validity bitmap
let mut new_offsets = Vec::with_capacity(list_array.len() + 1);
new_offsets.push(O::usize_as(0));
let mut list_valid = Vec::<bool>::with_capacity(list_array.len());

let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions
// Spark supports only Int32 position indices
let pos_data: &Int32Array = as_primitive_array(&pos_array);

for (row_index, offset_window) in offsets.windows(2).enumerate() {
let pos = pos_data.values()[row_index];
let start = offset_window[0].as_usize();
let end = offset_window[1].as_usize();
let is_item_null = items_array.is_null(row_index);
for (row_index, window) in offsets.windows(2).enumerate() {
let start = window[0].as_usize();
let end = window[1].as_usize();
let len = end - start;
let pos = pos_data.value(row_index);

if list_array.is_null(row_index) {
// In Spark if value of the array is NULL than nothing happens
mutable_values.extend_nulls(1);
new_offsets.push(new_offsets[row_index] + O::one());
new_nulls.push(false);
// Top-level list row is NULL: do not write any child values and do not advance offset
new_offsets.push(new_offsets[row_index]);
list_valid.push(false);
continue;
}

if pos == 0 {
return Err(DataFusionError::Internal(
"Position for array_insert should be greter or less than zero".to_string(),
"Position for array_insert should be greater or less than zero".to_string(),
));
}

if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
let corrected_pos = if pos > 0 {
(pos - 1).as_usize()
} else {
end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
};
let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
)));
}
let final_len: usize;

if (start + corrected_pos) <= end {
mutable_values.extend(0, start, start + corrected_pos);
if pos > 0 {
// Positive index (1-based)
let pos1 = pos as usize;
if pos1 <= len + 1 {
// In-range insertion (including appending to end)
let corrected = pos1 - 1; // 0-based insertion point
mutable_values.extend(0, start, start + corrected);
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend(0, start + corrected_pos, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
mutable_values.extend(0, start + corrected, end);
final_len = len + 1;
} else {
// Beyond end: pad with nulls then insert
let corrected = pos1 - 1;
let padding = corrected - len;
mutable_values.extend(0, start, end);
mutable_values.extend_nulls(new_array_len - (end - start));
mutable_values.extend_nulls(padding);
mutable_values.extend(1, row_index, row_index + 1);
// In that case spark actualy makes array longer than expected;
// For example, if pos is equal to 5, len is eq to 3, than resulted len will be 5
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
final_len = corrected + 1; // equals pos1
}
} else {
// This comment is takes from the Apache Spark source code as is:
// special case- if the new position is negative but larger than the current array size
// place the new item at start of array, place the current array contents at the end
// and fill the newly created array elements inbetween with a null
let base_offset = if legacy_mode { 1 } else { 0 };
let new_array_len = (-pos + base_offset).as_usize();
if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
)));
}
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend_nulls(new_array_len - (end - start + 1));
mutable_values.extend(0, start, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
}
if is_item_null {
if (start == end) || (values.is_null(row_index)) {
new_nulls.push(false)
// Negative index (1-based from the end)
let k = (-pos) as usize;

if k <= len {
// In-range negative insertion
// Non-legacy: -1 behaves like append to end (corrected = len - k + 1)
// Legacy: -1 behaves like insert before the last element (corrected = len - k)
let base_offset = if legacy_mode { 0 } else { 1 };
let corrected = len - k + base_offset;
mutable_values.extend(0, start, start + corrected);
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend(0, start + corrected, end);
final_len = len + 1;
} else {
new_nulls.push(true)
// Negative index beyond the start (Spark-specific behavior):
// Place item first, then pad with nulls, then append the original array.
// Final length = k + base_offset, where base_offset = 1 in legacy mode, otherwise 0.
let base_offset = if legacy_mode { 1 } else { 0 };
let target_len = k + base_offset;
let padding = target_len.saturating_sub(len + 1);
mutable_values.extend(1, row_index, row_index + 1); // insert item first
mutable_values.extend_nulls(padding); // pad nulls
mutable_values.extend(0, start, end); // append original values
final_len = target_len;
}
} else {
new_nulls.push(true)
}

if final_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH}, but got {final_len}"
)));
}

let prev = new_offsets[row_index].as_usize();
new_offsets.push(O::usize_as(prev + final_len));
list_valid.push(true);
}

let data = make_array(mutable_values.freeze());
let data_type = match list_array.data_type() {
DataType::List(field) => field.data_type(),
DataType::LargeList(field) => field.data_type(),
let child = make_array(mutable_values.freeze());

// Reuse the original list element field (name/type/nullability)
let elem_field = match list_array.data_type() {
DataType::List(field) => Arc::clone(field),
DataType::LargeList(field) => Arc::clone(field),
_ => unreachable!(),
};
let new_array = GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type.clone(), true)),

// Build the resulting list array
let new_list = GenericListArray::<O>::try_new(
elem_field,
OffsetBuffer::new(new_offsets.into()),
data,
Some(NullBuffer::new(new_nulls.into())),
child,
Some(NullBuffer::new(list_valid.into())),
)?;

Ok(ColumnarValue::Array(Arc::new(new_array)))
Ok(ColumnarValue::Array(Arc::new(new_list)))
}

impl Display for ArrayInsert {
@@ -442,4 +451,37 @@ mod test {

Ok(())
}

#[test]
fn test_array_insert_bug_repro_null_item_pos1_fixed() -> Result<()> {
use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
use arrow::datatypes::Int32Type;

// row0 = [0, null, 0]
// row1 = [1, null, 1]
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), None, Some(0)]),
Some(vec![Some(1), None, Some(1)]),
]);

let positions = Int32Array::from(vec![1, 1]);
let items = Int32Array::from(vec![None, None]);

let ColumnarValue::Array(result) = array_insert(
&list,
&(Arc::new(items) as ArrayRef),
&(Arc::new(positions) as ArrayRef),
false, // legacy_mode = false
)?
else {
unreachable!()
};

let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![None, Some(0), None, Some(0)]),
Some(vec![None, Some(1), None, Some(1)]),
]);
assert_eq!(&result.to_data(), &expected.to_data());
Ok(())
}
}
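For readers following the rewritten loop, the position handling can be summarized outside of Arrow entirely. The sketch below is a standalone model of the same rules described in the new comments (in-range insertion, null padding past the end, and the legacy vs. non-legacy negative-index offset), using plain vectors instead of MutableArrayData; the helper name insert_model is made up and this is an illustration, not the Comet implementation:

```rust
/// Standalone model of the position rules in array_insert above, written over
/// plain Vec<Option<i32>>. Illustrative only; offsets and validity buffers are
/// not modeled here.
fn insert_model(
    row: &[Option<i32>],
    pos: i32,
    item: Option<i32>,
    legacy_mode: bool,
) -> Vec<Option<i32>> {
    assert!(pos != 0, "position must be non-zero (1-based)");
    let len = row.len();
    let mut out = Vec::new();
    if pos > 0 {
        let pos1 = pos as usize;
        if pos1 <= len + 1 {
            // in-range: plain insertion at the 0-based index pos1 - 1
            let corrected = pos1 - 1;
            out.extend_from_slice(&row[..corrected]);
            out.push(item);
            out.extend_from_slice(&row[corrected..]);
        } else {
            // beyond the end: keep the row, pad with nulls, then the item;
            // the final length equals pos1
            out.extend_from_slice(row);
            out.extend(std::iter::repeat(None).take(pos1 - 1 - len));
            out.push(item);
        }
    } else {
        let k = (-pos) as usize;
        if k <= len {
            // in-range negative: -1 appends (non-legacy) or inserts before
            // the last element (legacy)
            let corrected = len - k + if legacy_mode { 0 } else { 1 };
            out.extend_from_slice(&row[..corrected]);
            out.push(item);
            out.extend_from_slice(&row[corrected..]);
        } else {
            // beyond the start: item first, null padding, then the original row
            let target_len = k + if legacy_mode { 1 } else { 0 };
            out.push(item);
            out.extend(std::iter::repeat(None).take(target_len.saturating_sub(len + 1)));
            out.extend_from_slice(row);
        }
    }
    out
}

fn main() {
    let row = [Some(1), Some(2), Some(3)];
    // positive, in range: insert before the second element
    assert_eq!(
        insert_model(&row, 2, Some(9), false),
        vec![Some(1), Some(9), Some(2), Some(3)]
    );
    // positive, beyond the end: pad with nulls so the item lands at position 5
    assert_eq!(
        insert_model(&row, 5, Some(9), false),
        vec![Some(1), Some(2), Some(3), None, Some(9)]
    );
    // negative, in range: -1 appends when legacy mode is off
    assert_eq!(
        insert_model(&row, -1, Some(9), false),
        vec![Some(1), Some(2), Some(3), Some(9)]
    );
    // negative, in range, legacy: -1 inserts before the last element
    assert_eq!(
        insert_model(&row, -1, Some(9), true),
        vec![Some(1), Some(2), Some(9), Some(3)]
    );
    // negative, beyond the start: item first, then the original row
    assert_eq!(
        insert_model(&row, -4, Some(9), false),
        vec![Some(9), Some(1), Some(2), Some(3)]
    );
}
```

Running the model exercises the same cases the new comments call out, including the Spark-specific behavior for negative positions beyond the start of the array.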
@@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.ArrayType

import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
import org.apache.comet.DataTypeSupport.isComplexType
@@ -777,4 +778,28 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
}
}
}

test("array_reverse 2") {
// This test validates data correctness for array<binary> columns with nullable elements.
// See https://github.com/apache/datafusion-comet/issues/2612
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
val filename = path.toString
val random = new Random(42)
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
val schemaOptions =
SchemaGenOptions(generateArray = true, generateStruct = false, generateMap = false)
val dataOptions = DataGenOptions(allowNull = true, generateNegativeZero = false)
ParquetGenerator.makeParquetFile(random, spark, filename, 100, schemaOptions, dataOptions)
}
withTempView("t1") {
val table = spark.read.parquet(filename)
table.createOrReplaceTempView("t1")
for (field <- table.schema.fields.filter(_.dataType.isInstanceOf[ArrayType])) {
val sql = s"SELECT ${field.name}, reverse(${field.name}) FROM t1 ORDER BY ${field.name}"
checkSparkAnswer(sql)
}
}
}
}
}