Skip to content

Commit 968826e

Browse files
authored
Merge pull request rapidsai#18619 from rapidsai/branch-25.06
Forward-merge branch-25.06 into branch-25.08
2 parents 7bf6674 + b5dc95e commit 968826e

9 files changed

+428
-108
lines changed

cpp/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,10 @@ add_library(
772772
src/text/wordpiece_tokenize.cu
773773
src/transform/bools_to_mask.cu
774774
src/transform/compute_column.cu
775+
src/transform/compute_column_kernel_complex.cu
776+
src/transform/compute_column_kernel_null_complex.cu
777+
src/transform/compute_column_kernel_null_primitive.cu
778+
src/transform/compute_column_kernel_primitive.cu
775779
src/transform/encode.cu
776780
src/transform/mask_to_bools.cu
777781
src/transform/nans_to_nulls.cu

cpp/include/cudf/ast/detail/expression_evaluator.cuh

Lines changed: 129 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,27 @@
2929

3030
#include <rmm/cuda_stream_view.hpp>
3131

32-
namespace cudf {
32+
#include <cuda/std/type_traits>
33+
#include <cuda/std/utility>
3334

34-
namespace ast {
35+
namespace cudf::ast::detail {
3536

36-
namespace detail {
37+
CUDF_HOST_DEVICE constexpr bool is_complex_type(cudf::type_id type)
38+
{
39+
return type == cudf::type_id::DECIMAL32 || type == cudf::type_id::DECIMAL64 ||
40+
type == cudf::type_id::DECIMAL128 || type == cudf::type_id::STRING;
41+
}
42+
43+
/**
44+
* @brief Maps void for string and decimal types
45+
*
46+
* @tparam t The cudf::type_id to map to a C++ type or void
47+
*/
48+
template <cudf::type_id t>
49+
struct dispatch_void_if_complex {
50+
// Default to void for non-primitive types
51+
using type = cuda::std::conditional_t<is_complex_type(t), void, id_to_type<t>>;
52+
};
3753

3854
/**
3955
* @brief A container for capturing the output of an evaluated expression.
@@ -97,7 +113,7 @@ struct value_expression_result
97113
__device__ inline void set_value(cudf::size_type index,
98114
possibly_null_value_t<Element, has_nulls> const& result)
99115
{
100-
if constexpr (std::is_same_v<Element, T>) {
116+
if constexpr (cuda::std::is_same_v<Element, T>) {
101117
_obj = result;
102118
} else {
103119
CUDF_UNREACHABLE("Output type does not match container type.");
@@ -214,7 +230,7 @@ struct single_dispatch_binary_operator {
214230
template <typename LHS, typename F, typename... Ts>
215231
__device__ inline auto operator()(F&& f, Ts&&... args)
216232
{
217-
f.template operator()<LHS, LHS>(std::forward<Ts>(args)...);
233+
f.template operator()<LHS, LHS>(cuda::std::forward<Ts>(args)...);
218234
}
219235
};
220236

@@ -224,7 +240,7 @@ struct single_dispatch_binary_operator {
224240
* This class is designed for n-ary transform evaluation. It operates on two
225241
* tables.
226242
*/
227-
template <bool has_nulls>
243+
template <bool has_nulls, bool has_complex_type = true>
228244
struct expression_evaluator {
229245
public:
230246
/**
@@ -345,7 +361,11 @@ struct expression_evaluator {
345361
* @param output_row_index The row in the output to insert the result.
346362
* @param op The operator to act with.
347363
*/
348-
template <typename Input, typename ResultSubclass, typename T, bool result_has_nulls>
364+
template <typename Input,
365+
typename ResultSubclass,
366+
typename T,
367+
bool result_has_nulls,
368+
CUDF_ENABLE_IF(!cuda::std::is_void_v<Input>)>
349369
__device__ inline void operator()(
350370
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
351371
cudf::size_type const input_row_index,
@@ -366,6 +386,23 @@ struct expression_evaluator {
366386
thread_intermediate_storage);
367387
}
368388

389+
template <typename Input,
390+
typename ResultSubclass,
391+
typename T,
392+
bool result_has_nulls,
393+
CUDF_ENABLE_IF(cuda::std::is_void_v<Input>)>
394+
__device__ inline void operator()(
395+
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
396+
cudf::size_type const input_row_index,
397+
detail::device_data_reference const& input,
398+
detail::device_data_reference const& output,
399+
cudf::size_type const output_row_index,
400+
ast_operator const op,
401+
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
402+
{
403+
CUDF_UNREACHABLE("Unsupported type in operator().");
404+
}
405+
369406
/**
370407
* @brief Callable to perform a binary operation.
371408
*
@@ -382,7 +419,12 @@ struct expression_evaluator {
382419
* @param output_row_index The row in the output to insert the result.
383420
* @param op The operator to act with.
384421
*/
385-
template <typename LHS, typename RHS, typename ResultSubclass, typename T, bool result_has_nulls>
422+
template <typename LHS,
423+
typename RHS,
424+
typename ResultSubclass,
425+
typename T,
426+
bool result_has_nulls,
427+
CUDF_ENABLE_IF(!cuda::std::is_void_v<LHS> && !cuda::std::is_void_v<RHS>)>
386428
__device__ inline void operator()(
387429
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
388430
cudf::size_type const left_row_index,
@@ -408,6 +450,26 @@ struct expression_evaluator {
408450
thread_intermediate_storage);
409451
}
410452

453+
template <typename LHS,
454+
typename RHS,
455+
typename ResultSubclass,
456+
typename T,
457+
bool result_has_nulls,
458+
CUDF_ENABLE_IF(cuda::std::is_void_v<LHS> || cuda::std::is_void_v<RHS>)>
459+
__device__ inline void operator()(
460+
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
461+
cudf::size_type const left_row_index,
462+
cudf::size_type const right_row_index,
463+
detail::device_data_reference const& lhs,
464+
detail::device_data_reference const& rhs,
465+
detail::device_data_reference const& output,
466+
cudf::size_type const output_row_index,
467+
ast_operator const op,
468+
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
469+
{
470+
CUDF_UNREACHABLE("Unsupported type in operator().");
471+
}
472+
411473
/**
412474
* @brief Evaluate an expression applied to a row.
413475
*
@@ -461,15 +523,27 @@ struct expression_evaluator {
461523
plan.data_references[plan.operator_source_indices[operator_source_index++]];
462524
auto input_row_index =
463525
input.table_source == table_reference::LEFT ? left_row_index : right_row_index;
464-
type_dispatcher(input.data_type,
465-
*this,
466-
output_object,
467-
input_row_index,
468-
input,
469-
output,
470-
output_row_index,
471-
op,
472-
thread_intermediate_storage);
526+
if constexpr (has_complex_type) {
527+
type_dispatcher(input.data_type,
528+
*this,
529+
output_object,
530+
input_row_index,
531+
input,
532+
output,
533+
output_row_index,
534+
op,
535+
thread_intermediate_storage);
536+
} else {
537+
type_dispatcher<dispatch_void_if_complex>(input.data_type,
538+
*this,
539+
output_object,
540+
input_row_index,
541+
input,
542+
output,
543+
output_row_index,
544+
op,
545+
thread_intermediate_storage);
546+
}
473547
} else if (arity == 2) {
474548
// Binary operator
475549
auto const& lhs =
@@ -478,20 +552,33 @@ struct expression_evaluator {
478552
plan.data_references[plan.operator_source_indices[operator_source_index++]];
479553
auto const& output =
480554
plan.data_references[plan.operator_source_indices[operator_source_index++]];
481-
type_dispatcher(lhs.data_type,
482-
detail::single_dispatch_binary_operator{},
483-
*this,
484-
output_object,
485-
left_row_index,
486-
right_row_index,
487-
lhs,
488-
rhs,
489-
output,
490-
output_row_index,
491-
op,
492-
thread_intermediate_storage);
493-
} else {
494-
CUDF_UNREACHABLE("Invalid operator arity.");
555+
if constexpr (has_complex_type) {
556+
type_dispatcher(lhs.data_type,
557+
detail::single_dispatch_binary_operator{},
558+
*this,
559+
output_object,
560+
left_row_index,
561+
right_row_index,
562+
lhs,
563+
rhs,
564+
output,
565+
output_row_index,
566+
op,
567+
thread_intermediate_storage);
568+
} else {
569+
type_dispatcher<dispatch_void_if_complex>(lhs.data_type,
570+
detail::single_dispatch_binary_operator{},
571+
*this,
572+
output_object,
573+
left_row_index,
574+
right_row_index,
575+
lhs,
576+
rhs,
577+
output,
578+
output_row_index,
579+
op,
580+
thread_intermediate_storage);
581+
}
495582
}
496583
}
497584
}
@@ -589,9 +676,8 @@ struct expression_evaluator {
589676
typename ResultSubclass,
590677
typename T,
591678
bool result_has_nulls,
592-
std::enable_if_t<
593-
detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
594-
possibly_null_value_t<Input, has_nulls>>>* = nullptr>
679+
CUDF_ENABLE_IF(detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
680+
possibly_null_value_t<Input, has_nulls>>)>
595681
__device__ inline void operator()(
596682
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
597683
cudf::size_type const output_row_index,
@@ -613,9 +699,8 @@ struct expression_evaluator {
613699
typename ResultSubclass,
614700
typename T,
615701
bool result_has_nulls,
616-
std::enable_if_t<
617-
!detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
618-
possibly_null_value_t<Input, has_nulls>>>* = nullptr>
702+
CUDF_ENABLE_IF(!detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
703+
possibly_null_value_t<Input, has_nulls>>)>
619704
__device__ inline void operator()(
620705
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
621706
cudf::size_type const output_row_index,
@@ -653,10 +738,9 @@ struct expression_evaluator {
653738
typename ResultSubclass,
654739
typename T,
655740
bool result_has_nulls,
656-
std::enable_if_t<detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
657-
possibly_null_value_t<LHS, has_nulls>,
658-
possibly_null_value_t<RHS, has_nulls>>>* =
659-
nullptr>
741+
CUDF_ENABLE_IF(detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
742+
possibly_null_value_t<LHS, has_nulls>,
743+
possibly_null_value_t<RHS, has_nulls>>)>
660744
__device__ inline void operator()(
661745
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
662746
cudf::size_type const output_row_index,
@@ -679,10 +763,9 @@ struct expression_evaluator {
679763
typename ResultSubclass,
680764
typename T,
681765
bool result_has_nulls,
682-
std::enable_if_t<
683-
!detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
684-
possibly_null_value_t<LHS, has_nulls>,
685-
possibly_null_value_t<RHS, has_nulls>>>* = nullptr>
766+
CUDF_ENABLE_IF(!detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
767+
possibly_null_value_t<LHS, has_nulls>,
768+
possibly_null_value_t<RHS, has_nulls>>)>
686769
__device__ inline void operator()(
687770
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
688771
cudf::size_type const output_row_index,
@@ -701,8 +784,4 @@ struct expression_evaluator {
701784
plan; ///< The container of device data representing the expression to evaluate.
702785
};
703786

704-
} // namespace detail
705-
706-
} // namespace ast
707-
708-
} // namespace cudf
787+
} // namespace cudf::ast::detail

0 commit comments

Comments
 (0)