Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 51 additions & 9 deletions src/is_na.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace duckdb {
namespace rfuns {

void isna_double_loop(idx_t count, const double* data, bool* result_data, ValidityMask mask) {
void isna_double_loop_flat(idx_t count, const double* data, bool* result_data, ValidityMask mask) {
idx_t base_idx = 0;
auto entry_count = ValidityMask::EntryCount(count);
for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) {
Expand Down Expand Up @@ -41,6 +41,24 @@ void isna_double_loop(idx_t count, const double* data, bool* result_data, Validi
}
}

void isna_double_loop_unified(idx_t count, const double* data, bool* result_data, const SelectionVector* sel, ValidityMask mask) {
if (!mask.AllValid()) {
for (idx_t i = 0; i < count; i++) {
auto idx = sel->get_index(i);
if (mask.RowIsValidUnsafe(idx)) {
result_data[i] = std::isnan(data[idx]);
} else {
result_data[i] = true;
}
}
} else {
for (idx_t i = 0; i < count; i++) {
auto idx = sel->get_index(i);
result_data[i] = std::isnan(data[idx]);
}
}
}

void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
auto count = args.size();
auto input = args.data[0];
Expand All @@ -49,7 +67,7 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
case VectorType::FLAT_VECTOR: {
result.SetVectorType(VectorType::FLAT_VECTOR);

isna_double_loop(
isna_double_loop_flat(
count,
FlatVector::GetData<double>(input),
FlatVector::GetData<bool>(result),
Expand All @@ -60,7 +78,7 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
}

case VectorType::CONSTANT_VECTOR: {
result.SetVectorType(VectorType::CONSTANT_VECTOR);
result.SetVectorType(VectorType::FLAT_VECTOR);
auto result_data = ConstantVector::GetData<bool>(result);
auto ldata = ConstantVector::GetData<double>(input);

Expand All @@ -74,10 +92,11 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
input.ToUnifiedFormat(count, vdata);
result.SetVectorType(VectorType::FLAT_VECTOR);

isna_double_loop(
isna_double_loop_unified(
count,
UnifiedVectorFormat::GetData<double>(vdata),
FlatVector::GetData<bool>(result),
vdata.sel,
vdata.validity
);

Expand All @@ -86,7 +105,7 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) {
}
}

void isna_any_loop(idx_t count, bool* result_data, ValidityMask mask) {
void isna_any_loop_flat(idx_t count, bool* result_data, ValidityMask mask) {
if (mask.AllValid()) {
for (idx_t i = 0; i < count; i++) {
result_data[i] = false;
Expand Down Expand Up @@ -121,14 +140,31 @@ void isna_any_loop(idx_t count, bool* result_data, ValidityMask mask) {

}

void isna_any_loop_unified(idx_t count, bool* result_data, const SelectionVector* sel, ValidityMask mask) {
if (!mask.AllValid()) {
for (idx_t i = 0; i < count; i++) {
auto idx = sel->get_index(i);
if (mask.RowIsValidUnsafe(idx)) {
result_data[i] = false;
} else {
result_data[i] = true;
}
}
} else {
for (idx_t i = 0; i < count; i++) {
result_data[i] = true;
}
}
}

void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
auto count = args.size();
auto input = args.data[0];

switch(input.GetVectorType()) {
case VectorType::FLAT_VECTOR: {
result.SetVectorType(VectorType::FLAT_VECTOR);
isna_any_loop(
isna_any_loop_flat(
count,
FlatVector::GetData<bool>(result),
FlatVector::Validity(input)
Expand All @@ -149,9 +185,10 @@ void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
UnifiedVectorFormat vdata;
input.ToUnifiedFormat(count, vdata);
result.SetVectorType(VectorType::FLAT_VECTOR);
isna_any_loop(
isna_any_loop_unified(
count,
FlatVector::GetData<bool>(result),
vdata.sel,
vdata.validity
);

Expand All @@ -165,8 +202,13 @@ void isna_any(DataChunk &args, ExpressionState &state, Vector &result) {
ScalarFunctionSet base_r_is_na() {
ScalarFunctionSet set("r_base::is.na");

set.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, isna_double));
set.AddFunction(ScalarFunction({LogicalType::ANY} , LogicalType::BOOLEAN, isna_any));
ScalarFunction is_na_double({LogicalType::DOUBLE}, LogicalType::BOOLEAN, isna_double);
is_na_double.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
set.AddFunction(is_na_double);

ScalarFunction is_na_any({LogicalType::ANY} , LogicalType::BOOLEAN, isna_any);
is_na_any.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
set.AddFunction(is_na_any);

return set;
}
Expand Down
Loading