2929
3030#include < rmm/cuda_stream_view.hpp>
3131
32- namespace cudf {
32+ #include < cuda/std/type_traits>
33+ #include < cuda/std/utility>
3334
34- namespace ast {
35+ namespace cudf :: ast::detail {
3536
36- namespace detail {
37+ CUDF_HOST_DEVICE constexpr bool is_complex_type (cudf::type_id type)
38+ {
39+ return type == cudf::type_id::DECIMAL32 || type == cudf::type_id::DECIMAL64 ||
40+ type == cudf::type_id::DECIMAL128 || type == cudf::type_id::STRING;
41+ }
42+
43+ /* *
44+ * @brief Maps void for string and decimal types
45+ *
46+ * @tparam t The cudf::type_id to map to a C++ type or void
47+ */
48+ template <cudf::type_id t>
49+ struct dispatch_void_if_complex {
50+ // Default to void for non-primitive types
51+ using type = cuda::std::conditional_t <is_complex_type(t), void , id_to_type<t>>;
52+ };
3753
3854/* *
3955 * @brief A container for capturing the output of an evaluated expression.
@@ -97,7 +113,7 @@ struct value_expression_result
97113 __device__ inline void set_value (cudf::size_type index,
98114 possibly_null_value_t <Element, has_nulls> const & result)
99115 {
100- if constexpr (std::is_same_v<Element, T>) {
116+ if constexpr (cuda:: std::is_same_v<Element, T>) {
101117 _obj = result;
102118 } else {
103119 CUDF_UNREACHABLE (" Output type does not match container type." );
@@ -214,7 +230,7 @@ struct single_dispatch_binary_operator {
214230 template <typename LHS, typename F, typename ... Ts>
215231 __device__ inline auto operator ()(F&& f, Ts&&... args)
216232 {
217- f.template operator ()<LHS, LHS>(std::forward<Ts>(args)...);
233+ f.template operator ()<LHS, LHS>(cuda:: std::forward<Ts>(args)...);
218234 }
219235};
220236
@@ -224,7 +240,7 @@ struct single_dispatch_binary_operator {
224240 * This class is designed for n-ary transform evaluation. It operates on two
225241 * tables.
226242 */
227- template <bool has_nulls>
243+ template <bool has_nulls, bool has_complex_type = true >
228244struct expression_evaluator {
229245 public:
230246 /* *
@@ -345,7 +361,11 @@ struct expression_evaluator {
345361 * @param output_row_index The row in the output to insert the result.
346362 * @param op The operator to act with.
347363 */
348- template <typename Input, typename ResultSubclass, typename T, bool result_has_nulls>
364+ template <typename Input,
365+ typename ResultSubclass,
366+ typename T,
367+ bool result_has_nulls,
368+ CUDF_ENABLE_IF (!cuda::std::is_void_v<Input>)>
349369 __device__ inline void operator ()(
350370 expression_result<ResultSubclass, T, result_has_nulls>& output_object,
351371 cudf::size_type const input_row_index,
@@ -366,6 +386,23 @@ struct expression_evaluator {
366386 thread_intermediate_storage);
367387 }
368388
389+ template <typename Input,
390+ typename ResultSubclass,
391+ typename T,
392+ bool result_has_nulls,
393+ CUDF_ENABLE_IF (cuda::std::is_void_v<Input>)>
394+ __device__ inline void operator ()(
395+ expression_result<ResultSubclass, T, result_has_nulls>& output_object,
396+ cudf::size_type const input_row_index,
397+ detail::device_data_reference const & input,
398+ detail::device_data_reference const & output,
399+ cudf::size_type const output_row_index,
400+ ast_operator const op,
401+ IntermediateDataType<has_nulls>* thread_intermediate_storage) const
402+ {
403+ CUDF_UNREACHABLE (" Unsupported type in operator()." );
404+ }
405+
369406 /* *
370407 * @brief Callable to perform a binary operation.
371408 *
@@ -382,7 +419,12 @@ struct expression_evaluator {
382419 * @param output_row_index The row in the output to insert the result.
383420 * @param op The operator to act with.
384421 */
385- template <typename LHS, typename RHS, typename ResultSubclass, typename T, bool result_has_nulls>
422+ template <typename LHS,
423+ typename RHS,
424+ typename ResultSubclass,
425+ typename T,
426+ bool result_has_nulls,
427+ CUDF_ENABLE_IF (!cuda::std::is_void_v<LHS> && !cuda::std::is_void_v<RHS>)>
386428 __device__ inline void operator ()(
387429 expression_result<ResultSubclass, T, result_has_nulls>& output_object,
388430 cudf::size_type const left_row_index,
@@ -408,6 +450,26 @@ struct expression_evaluator {
408450 thread_intermediate_storage);
409451 }
410452
453+ template <typename LHS,
454+ typename RHS,
455+ typename ResultSubclass,
456+ typename T,
457+ bool result_has_nulls,
458+ CUDF_ENABLE_IF (cuda::std::is_void_v<LHS> || cuda::std::is_void_v<RHS>)>
459+ __device__ inline void operator ()(
460+ expression_result<ResultSubclass, T, result_has_nulls>& output_object,
461+ cudf::size_type const left_row_index,
462+ cudf::size_type const right_row_index,
463+ detail::device_data_reference const & lhs,
464+ detail::device_data_reference const & rhs,
465+ detail::device_data_reference const & output,
466+ cudf::size_type const output_row_index,
467+ ast_operator const op,
468+ IntermediateDataType<has_nulls>* thread_intermediate_storage) const
469+ {
470+ CUDF_UNREACHABLE (" Unsupported type in operator()." );
471+ }
472+
411473 /* *
412474 * @brief Evaluate an expression applied to a row.
413475 *
@@ -461,15 +523,27 @@ struct expression_evaluator {
461523 plan.data_references [plan.operator_source_indices [operator_source_index++]];
462524 auto input_row_index =
463525 input.table_source == table_reference::LEFT ? left_row_index : right_row_index;
464- type_dispatcher (input.data_type ,
465- *this ,
466- output_object,
467- input_row_index,
468- input,
469- output,
470- output_row_index,
471- op,
472- thread_intermediate_storage);
526+ if constexpr (has_complex_type) {
527+ type_dispatcher (input.data_type ,
528+ *this ,
529+ output_object,
530+ input_row_index,
531+ input,
532+ output,
533+ output_row_index,
534+ op,
535+ thread_intermediate_storage);
536+ } else {
537+ type_dispatcher<dispatch_void_if_complex>(input.data_type ,
538+ *this ,
539+ output_object,
540+ input_row_index,
541+ input,
542+ output,
543+ output_row_index,
544+ op,
545+ thread_intermediate_storage);
546+ }
473547 } else if (arity == 2 ) {
474548 // Binary operator
475549 auto const & lhs =
@@ -478,20 +552,33 @@ struct expression_evaluator {
478552 plan.data_references [plan.operator_source_indices [operator_source_index++]];
479553 auto const & output =
480554 plan.data_references [plan.operator_source_indices [operator_source_index++]];
481- type_dispatcher (lhs.data_type ,
482- detail::single_dispatch_binary_operator{},
483- *this ,
484- output_object,
485- left_row_index,
486- right_row_index,
487- lhs,
488- rhs,
489- output,
490- output_row_index,
491- op,
492- thread_intermediate_storage);
493- } else {
494- CUDF_UNREACHABLE (" Invalid operator arity." );
555+ if constexpr (has_complex_type) {
556+ type_dispatcher (lhs.data_type ,
557+ detail::single_dispatch_binary_operator{},
558+ *this ,
559+ output_object,
560+ left_row_index,
561+ right_row_index,
562+ lhs,
563+ rhs,
564+ output,
565+ output_row_index,
566+ op,
567+ thread_intermediate_storage);
568+ } else {
569+ type_dispatcher<dispatch_void_if_complex>(lhs.data_type ,
570+ detail::single_dispatch_binary_operator{},
571+ *this ,
572+ output_object,
573+ left_row_index,
574+ right_row_index,
575+ lhs,
576+ rhs,
577+ output,
578+ output_row_index,
579+ op,
580+ thread_intermediate_storage);
581+ }
495582 }
496583 }
497584 }
@@ -589,9 +676,8 @@ struct expression_evaluator {
589676 typename ResultSubclass,
590677 typename T,
591678 bool result_has_nulls,
592- std::enable_if_t <
593- detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
594- possibly_null_value_t <Input, has_nulls>>>* = nullptr >
679+ CUDF_ENABLE_IF (detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
680+ possibly_null_value_t <Input, has_nulls>>)>
595681 __device__ inline void operator ()(
596682 expression_result<ResultSubclass, T, result_has_nulls>& output_object,
597683 cudf::size_type const output_row_index,
@@ -613,9 +699,8 @@ struct expression_evaluator {
613699 typename ResultSubclass,
614700 typename T,
615701 bool result_has_nulls,
616- std::enable_if_t <
617- !detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
618- possibly_null_value_t <Input, has_nulls>>>* = nullptr >
702+ CUDF_ENABLE_IF (!detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
703+ possibly_null_value_t <Input, has_nulls>>)>
619704 __device__ inline void operator ()(
620705 expression_result<ResultSubclass, T, result_has_nulls>& output_object,
621706 cudf::size_type const output_row_index,
@@ -653,10 +738,9 @@ struct expression_evaluator {
653738 typename ResultSubclass,
654739 typename T,
655740 bool result_has_nulls,
656- std::enable_if_t <detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
657- possibly_null_value_t <LHS, has_nulls>,
658- possibly_null_value_t <RHS, has_nulls>>>* =
659- nullptr >
741+ CUDF_ENABLE_IF (detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
742+ possibly_null_value_t <LHS, has_nulls>,
743+ possibly_null_value_t <RHS, has_nulls>>)>
660744 __device__ inline void operator ()(
661745 expression_result<ResultSubclass, T, result_has_nulls>& output_object,
662746 cudf::size_type const output_row_index,
@@ -679,10 +763,9 @@ struct expression_evaluator {
679763 typename ResultSubclass,
680764 typename T,
681765 bool result_has_nulls,
682- std::enable_if_t <
683- !detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
684- possibly_null_value_t <LHS, has_nulls>,
685- possibly_null_value_t <RHS, has_nulls>>>* = nullptr >
766+ CUDF_ENABLE_IF (!detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
767+ possibly_null_value_t <LHS, has_nulls>,
768+ possibly_null_value_t <RHS, has_nulls>>)>
686769 __device__ inline void operator ()(
687770 expression_result<ResultSubclass, T, result_has_nulls>& output_object,
688771 cudf::size_type const output_row_index,
@@ -701,8 +784,4 @@ struct expression_evaluator {
701784 plan; // /< The container of device data representing the expression to evaluate.
702785};
703786
704- } // namespace detail
705-
706- } // namespace ast
707-
708- } // namespace cudf
787+ } // namespace cudf::ast::detail
0 commit comments