diff --git a/src/is_na.cpp b/src/is_na.cpp
index 74a9f6e..a976d39 100644
--- a/src/is_na.cpp
+++ b/src/is_na.cpp
@@ -9,7 +9,7 @@
 namespace duckdb {
 namespace rfuns {
 
-void isna_double_loop(idx_t count, const double* data, bool* result_data, ValidityMask mask) {
+void isna_double_loop_flat(idx_t count, const double* data, bool* result_data, ValidityMask mask) {
 	idx_t base_idx = 0;
 	auto entry_count = ValidityMask::EntryCount(count);
 	for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) {
@@ -41,6 +41,27 @@ void isna_double_loop(idx_t count, const double* data, bool* result_data, Validi
 	}
 }
 
+// Fills result_data[i] for i in [0, count): true when the row selected by sel is
+// marked invalid in mask or its double value is NaN, false otherwise.
+// data and mask are indexed through sel; result_data is indexed directly by i.
+void isna_double_loop_unified(idx_t count, const double* data, bool* result_data, const SelectionVector* sel, ValidityMask mask) {
+	if (!mask.AllValid()) {
+		for (idx_t i = 0; i < count; i++) {
+			auto idx = sel->get_index(i);
+			if (mask.RowIsValidUnsafe(idx)) {
+				result_data[i] = std::isnan(data[idx]);
+			} else {
+				result_data[i] = true; // invalid (NULL) row counts as NA
+			}
+		}
+	} else {
+		for (idx_t i = 0; i < count; i++) {
+			auto idx = sel->get_index(i);
+			result_data[i] = std::isnan(data[idx]);
+		}
+	}
+}
+
 void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
 	auto count = args.size();
 	auto input = args.data[0];
@@ -49,7 +70,7 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
 		case VectorType::FLAT_VECTOR: {
 			result.SetVectorType(VectorType::FLAT_VECTOR);
 
-			isna_double_loop(
+			isna_double_loop_flat(
 				count,
 				FlatVector::GetData(input),
 				FlatVector::GetData(result),
@@ -60,7 +81,7 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
 		}
 
 		case VectorType::CONSTANT_VECTOR: {
-			result.SetVectorType(VectorType::CONSTANT_VECTOR);
+			result.SetVectorType(VectorType::FLAT_VECTOR); // NOTE(review): result is FLAT but still filled via ConstantVector::GetData below - confirm every one of the count rows gets written
 
 			auto result_data = ConstantVector::GetData(result);
 			auto ldata = ConstantVector::GetData(input);
@@ -74,10 +95,11 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
 			input.ToUnifiedFormat(count, vdata);
 			result.SetVectorType(VectorType::FLAT_VECTOR);
 
-			isna_double_loop(
+			isna_double_loop_unified(
 				count,
 				UnifiedVectorFormat::GetData(vdata),
 				FlatVector::GetData(result),
+				vdata.sel,
 				vdata.validity
 			);
 
@@ -86,7 +108,7 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
 	}
 }
 
-void isna_any_loop(idx_t count, bool* result_data, ValidityMask mask) {
+void isna_any_loop_flat(idx_t count, bool* result_data, ValidityMask mask) {
 	if (mask.AllValid()) {
 		for (idx_t i = 0; i < count; i++) {
 			result_data[i] = false;
@@ -121,6 +143,26 @@ void isna_any_loop(idx_t count, bool* result_data, ValidityMask mask) {
 
 }
 
+// Fills result_data[i] for i in [0, count): true when the row selected by sel is
+// marked invalid in mask, false otherwise. With an all-valid mask no row can be
+// NA, so every output is false (same as the AllValid branch of isna_any_loop_flat).
+void isna_any_loop_unified(idx_t count, bool* result_data, const SelectionVector* sel, ValidityMask mask) {
+	if (!mask.AllValid()) {
+		for (idx_t i = 0; i < count; i++) {
+			auto idx = sel->get_index(i);
+			if (mask.RowIsValidUnsafe(idx)) {
+				result_data[i] = false;
+			} else {
+				result_data[i] = true; // invalid (NULL) row counts as NA
+			}
+		}
+	} else {
+		for (idx_t i = 0; i < count; i++) {
+			result_data[i] = false; // every row is valid, so nothing is NA
+		}
+	}
+}
+
 void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
 	auto count = args.size();
 	auto input = args.data[0];
@@ -128,7 +170,7 @@ void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
 	switch(input.GetVectorType()) {
 		case VectorType::FLAT_VECTOR: {
 			result.SetVectorType(VectorType::FLAT_VECTOR);
-			isna_any_loop(
+			isna_any_loop_flat(
 				count,
 				FlatVector::GetData(result),
 				FlatVector::Validity(input)
@@ -149,9 +191,10 @@ void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
 			UnifiedVectorFormat vdata;
 			input.ToUnifiedFormat(count, vdata);
 			result.SetVectorType(VectorType::FLAT_VECTOR);
-			isna_any_loop(
+			isna_any_loop_unified(
 				count,
 				FlatVector::GetData(result),
+				vdata.sel,
 				vdata.validity
 			);
 
@@ -165,8 +208,13 @@ void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
 
 ScalarFunctionSet base_r_is_na() {
 	ScalarFunctionSet set("r_base::is.na");
-	set.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, isna_double));
-	set.AddFunction(ScalarFunction({LogicalType::ANY} , LogicalType::BOOLEAN, isna_any));
+	ScalarFunction is_na_double({LogicalType::DOUBLE}, LogicalType::BOOLEAN, isna_double);
+	is_na_double.null_handling = FunctionNullHandling::SPECIAL_HANDLING; // presumably lets NULL rows reach isna_double instead of short-circuiting to NULL - confirm
+	set.AddFunction(is_na_double);
+
+	ScalarFunction is_na_any({LogicalType::ANY} , LogicalType::BOOLEAN, isna_any);
+	is_na_any.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+	set.AddFunction(is_na_any);
 
 	return set;
 }