diff --git a/src/include/rfuns_extension.hpp b/src/include/rfuns_extension.hpp index 0a2e3ff03..543c79dcb 100644 --- a/src/include/rfuns_extension.hpp +++ b/src/include/rfuns_extension.hpp @@ -61,11 +61,12 @@ ScalarFunctionSet base_r_gte(); ScalarFunctionSet base_r_is_na(); ScalarFunctionSet base_r_as_integer(); +ScalarFunctionSet base_r_as_numeric(); // sum -AggregateFunctionSet base_r_sum(); -AggregateFunctionSet base_r_min(); -AggregateFunctionSet base_r_max(); +AggregateFunctionSet base_r_aggregate_sum(); +AggregateFunctionSet base_r_aggregate_min(); +AggregateFunctionSet base_r_aggregate_max(); ScalarFunctionSet binary_dispatch(ScalarFunctionSet fn) ; diff --git a/src/rfuns.cpp b/src/rfuns.cpp index e5231d95b..9ae1fe338 100644 --- a/src/rfuns.cpp +++ b/src/rfuns.cpp @@ -4,6 +4,7 @@ #include #include +#include namespace duckdb { namespace rfuns { @@ -29,7 +30,7 @@ void BaseRAddFunctionDouble(DataChunk &args, ExpressionState &state, Vector &res BinaryExecutor::ExecuteWithNulls( parts.lefts, parts.rights, result, args.size(), [&](double left, double right, ValidityMask &mask, idx_t idx) { - if (isnan(left) || isnan(right)) { + if (std::isnan(left) || std::isnan(right)) { mask.SetInvalid(idx); return 0.0; } @@ -38,7 +39,7 @@ void BaseRAddFunctionDouble(DataChunk &args, ExpressionState &state, Vector &res } double ExecuteBaseRPlusFunctionIntDouble(int32_t left, double right, ValidityMask &mask, idx_t idx) { - if (isnan(right)) { + if (std::isnan(right)) { mask.SetInvalid(idx); return 0.0; } @@ -86,6 +87,7 @@ ScalarFunctionSet base_r_add() { #include #include #include +#include namespace duckdb { namespace rfuns { @@ -93,7 +95,7 @@ namespace rfuns { namespace { template -int32_t check_range(T value, ValidityMask &mask, idx_t idx) { +int32_t check_int_range(T value, ValidityMask &mask, idx_t idx) { if (value > std::numeric_limits::max() || value < std::numeric_limits::min() ) { mask.SetInvalid(idx); } @@ -101,66 +103,90 @@ int32_t check_range(T value, ValidityMask &mask, idx_t idx) { return static_cast(value); } -template -int32_t cast(T input, ValidityMask &mask, idx_t idx) { - return static_cast(input); +template +TO cast(FROM input, ValidityMask &mask, idx_t idx) { + return static_cast(input); } template <> -int32_t cast(double input, ValidityMask &mask, idx_t idx) { - if (isnan(input)) { +int32_t cast(double input, ValidityMask &mask, idx_t idx) { + if (std::isnan(input)) { mask.SetInvalid(idx); } - return check_range(input, mask, idx); + return check_int_range(input, mask, idx); } template <> -int32_t cast(string_t input, ValidityMask &mask, idx_t idx) { +double cast(string_t input, ValidityMask &mask, idx_t idx) { double result; if (!TryDoubleCast(input.GetData(), input.GetSize(), result, false)) { mask.SetInvalid(idx); } - return cast(result, mask, idx); + return result; +} + +template <> +int32_t cast(string_t input, ValidityMask &mask, idx_t idx) { + auto dbl = cast(input, mask, idx); + return cast(dbl, mask, idx); +} + +template <> +int32_t cast(date_t input, ValidityMask &mask, idx_t idx) { + return input.days; } template <> -int32_t cast(date_t input, ValidityMask &mask, idx_t idx) { +double cast(date_t input, ValidityMask &mask, idx_t idx) { return input.days; } template <> -int32_t cast(timestamp_t input, ValidityMask &mask, idx_t idx) { - return check_range(Timestamp::GetEpochSeconds(input), mask, idx); +int32_t cast(timestamp_t input, ValidityMask &mask, idx_t idx) { + return check_int_range(Timestamp::GetEpochSeconds(input), mask, idx); } -template -ScalarFunction AsIntegerFunction() { +template <> +double cast(timestamp_t input, ValidityMask &mask, idx_t idx) { + return check_int_range(Timestamp::GetEpochSeconds(input), mask, idx); +} + +template +ScalarFunction AsNumberFunction() { using physical_type = typename physical::type; + using result_type = typename physical::type; auto fun = [](DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteWithNulls( - args.data[0], result, args.size(), cast + UnaryExecutor::ExecuteWithNulls( + args.data[0], result, args.size(), cast ); }; - return ScalarFunction({TYPE}, LogicalType::INTEGER, fun); + return ScalarFunction({TYPE}, RESULT_TYPE, fun); } -} +template +ScalarFunctionSet as_number(std::string name) { + ScalarFunctionSet set(name); -ScalarFunctionSet base_r_as_integer() { - ScalarFunctionSet set("r_base::as.integer"); + set.AddFunction(AsNumberFunction()); + set.AddFunction(AsNumberFunction()); + set.AddFunction(AsNumberFunction()); + set.AddFunction(AsNumberFunction()); + set.AddFunction(AsNumberFunction()); + set.AddFunction(AsNumberFunction()); - set.AddFunction(AsIntegerFunction()); - set.AddFunction(AsIntegerFunction()); - set.AddFunction(AsIntegerFunction()); + return set; +} - set.AddFunction(AsIntegerFunction()); +} - set.AddFunction(AsIntegerFunction()); - set.AddFunction(AsIntegerFunction()); +ScalarFunctionSet base_r_as_integer() { + return as_number("r_base::as.integer"); +} - return set; +ScalarFunctionSet base_r_as_numeric() { + return as_number("r_base::as.numeric"); } } @@ -206,19 +232,12 @@ ScalarFunctionSet binary_dispatch(ScalarFunctionSet fn) { #include #include #include +#include namespace duckdb { namespace rfuns { -void isna_double(DataChunk &args, ExpressionState &state, Vector &result) { - auto count = args.size(); - auto input = args.data[0]; - auto mask = FlatVector::Validity(input); - auto* data = FlatVector::GetData(input); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - +void isna_double_loop(idx_t count, const double* data, bool* result_data, ValidityMask mask) { idx_t base_idx = 0; auto entry_count = ValidityMask::EntryCount(count); for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { @@ -250,14 +269,52 @@ void isna_double(DataChunk &args, ExpressionState &state, Vector &result) { } } -void isna_any(DataChunk &args, ExpressionState &state, Vector &result) { +void isna_double(DataChunk &args, ExpressionState &state, Vector &result) { auto count = args.size(); auto input = args.data[0]; - auto mask = FlatVector::Validity(input); - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); + switch(input.GetVectorType()) { + case VectorType::FLAT_VECTOR: { + result.SetVectorType(VectorType::FLAT_VECTOR); + + isna_double_loop( + count, + FlatVector::GetData(input), + FlatVector::GetData(result), + FlatVector::Validity(input) + ); + + break; + } + case VectorType::CONSTANT_VECTOR: { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + auto result_data = ConstantVector::GetData(result); + auto ldata = ConstantVector::GetData(input); + + *result_data = ConstantVector::IsNull(input) || isnan(*ldata); + + break; + } + + default: { + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + result.SetVectorType(VectorType::FLAT_VECTOR); + + isna_double_loop( + count, + UnifiedVectorFormat::GetData(vdata), + FlatVector::GetData(result), + vdata.validity + ); + + break; + } + } +} + +void isna_any_loop(idx_t count, bool* result_data, ValidityMask mask) { if (mask.AllValid()) { for (idx_t i = 0; i < count; i++) { result_data[i] = false; @@ -289,6 +346,47 @@ void isna_any(DataChunk &args, ExpressionState &state, Vector &result) { } } } + +} + +void isna_any(DataChunk &args, ExpressionState &state, Vector &result) { + auto count = args.size(); + auto input = args.data[0]; + + switch(input.GetVectorType()) { + case VectorType::FLAT_VECTOR: { + result.SetVectorType(VectorType::FLAT_VECTOR); + isna_any_loop( + count, + FlatVector::GetData(result), + FlatVector::Validity(input) + ); + + break; + } + + case VectorType::CONSTANT_VECTOR: { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + auto result_data = ConstantVector::GetData(result); + *result_data = ConstantVector::IsNull(input); + + break; + } + + default : { + UnifiedVectorFormat vdata; + input.ToUnifiedFormat(count, vdata); + result.SetVectorType(VectorType::FLAT_VECTOR); + isna_any_loop( + count, + FlatVector::GetData(result), + vdata.validity + ); + + break; + } + } + } @@ -447,12 +545,12 @@ AggregateFunctionSet base_r_minmax(std::string name) { return set; } -AggregateFunctionSet base_r_min() { - return base_r_minmax("r_base::min"); +AggregateFunctionSet base_r_aggregate_min() { + return base_r_minmax("r_base::aggregate::min"); } -AggregateFunctionSet base_r_max() { - return base_r_minmax("r_base::max"); +AggregateFunctionSet base_r_aggregate_max() { + return base_r_minmax("r_base::aggregate::max"); } @@ -464,6 +562,7 @@ AggregateFunctionSet base_r_max() { #include #include #include +#include namespace duckdb { namespace rfuns { @@ -627,7 +726,7 @@ bool set_null(T value, ValidityMask &mask, idx_t idx) { template <> bool set_null(double value, ValidityMask &mask, idx_t idx) { - if (isnan(value)) { + if (std::isnan(value)) { mask.SetInvalid(idx); return true; } @@ -772,10 +871,11 @@ static void register_rfuns(DatabaseInstance &instance) { ExtensionUtil::RegisterFunction(instance, base_r_is_na()); ExtensionUtil::RegisterFunction(instance, base_r_as_integer()); + ExtensionUtil::RegisterFunction(instance, base_r_as_numeric()); - ExtensionUtil::RegisterFunction(instance, base_r_sum()); - ExtensionUtil::RegisterFunction(instance, base_r_min()); - ExtensionUtil::RegisterFunction(instance, base_r_max()); + ExtensionUtil::RegisterFunction(instance, base_r_aggregate_sum()); + ExtensionUtil::RegisterFunction(instance, base_r_aggregate_min()); + ExtensionUtil::RegisterFunction(instance, base_r_aggregate_max()); } } // namespace rfuns @@ -926,8 +1026,8 @@ void add_RSum(AggregateFunctionSet& set, const LogicalType& type) { )); } -AggregateFunctionSet base_r_sum() { - AggregateFunctionSet set("r_base::sum"); +AggregateFunctionSet base_r_aggregate_sum() { + AggregateFunctionSet set("r_base::aggregate::sum"); add_RSum(set, LogicalType::BOOLEAN); add_RSum(set, LogicalType::INTEGER);