Skip to content

Commit 59e63e6

Browse files
author
Yihan Wang
authored
[SYCLomatic] Refine migration of cub::DeviceSegmentedReduce::Reduce (#471)
Introduce an new value for option -use-experimental-features=user-defined-reductions to migrate cub::DeviceSegmentedReduce::Reduce to dpct::device::experimental::segmented_reduce. Signed-off-by: Wang, Yihan <[email protected]>
1 parent 15fe0e8 commit 59e63e6

File tree

10 files changed

+424
-30
lines changed

10 files changed

+424
-30
lines changed

clang/include/clang/DPCT/DPCTOptions.inc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,9 @@ DPCT_ENUM_OPTION(DPCT_OPT_TYPE(static llvm::cl::bits<ExperimentalFeatures>), Exp
369369
false),
370370
DPCT_OPT_ENUM("nd_range_barrier", int(ExperimentalFeatures::Exp_NdRangeBarrier),
371371
"Experimental helper function used to help cross group synchronization during migration.\n",
372+
false),
373+
DPCT_OPT_ENUM("user-defined-reductions", int(ExperimentalFeatures::Exp_UserDefineReductions),
374+
"Experimental extension that allows user define reductions.\n",
372375
false)
373376
),
374377
llvm::cl::desc("Comma separated list of experimental features to be used in migrated "

clang/lib/DPCT/APINamesCUB.inc

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -588,60 +588,93 @@ CONDITIONAL_FACTORY_ENTRY(
588588
CONDITIONAL_FACTORY_ENTRY(
589589
CheckCubRedundantFunctionCall(),
590590
REMOVE_API_FACTORY_ENTRY("cub::DeviceSegmentedReduce::Reduce"),
591-
REMOVE_CUB_TEMP_STORAGE_FACTORY(FEATURE_REQUEST_FACTORY(
592-
HelperFeatureEnum::DplExtrasDpcppExtensions_segmented_reduce,
593-
HEADER_INSERT_FACTORY(
594-
HeaderType::HT_DPCT_DPL_Utils,
595-
WARNING_FACTORY_ENTRY(
596-
"cub::DeviceSegmentedReduce::Reduce",
591+
REMOVE_CUB_TEMP_STORAGE_FACTORY(HEADER_INSERT_FACTORY(
592+
HeaderType::HT_DPCT_DPL_Utils,
593+
WARNING_FACTORY_ENTRY(
594+
"cub::DeviceSegmentedReduce::Reduce",
595+
CONDITIONAL_FACTORY_ENTRY(
596+
makeCheckAnd(CheckArgCount(10, std::greater_equal<>(),
597+
/* IncludeDefaultArg */ false),
598+
makeCheckNot(CheckArgIsDefaultCudaStream(9))),
597599
CONDITIONAL_FACTORY_ENTRY(
598-
makeCheckAnd(CheckArgCount(10, std::greater_equal<>(), /* IncludeDefaultArg */false),
599-
makeCheckNot(CheckArgIsDefaultCudaStream(9))),
600-
CONDITIONAL_FACTORY_ENTRY(
601-
checkArgCanMappingToSyclNativeBinaryOp(7),
600+
checkEnableUserDefineReductions(),
601+
FEATURE_REQUEST_FACTORY(
602+
HelperFeatureEnum::
603+
DplExtrasDpcppExtensions_segmented_reduce_ext,
602604
CALL_FACTORY_ENTRY(
603605
"cub::DeviceSegmentedReduce::Reduce",
604606
CALL(TEMPLATED_CALLEE_WITH_ARGS(
605607
MapNames::getDpctNamespace() +
606-
"device::segmented_reduce",
608+
"device::experimental::segmented_"
609+
"reduce",
607610
LITERAL("128")),
608611
STREAM(9), ARG(2), ARG(3), ARG(4), ARG(5),
609-
ARG(6), ARG(7), ARG(8))),
610-
WARNING_FACTORY_ENTRY(
611-
"cub::DeviceSegmentedReduce::Reduce",
612+
ARG(6), ARG(7), ARG(8)))),
613+
FEATURE_REQUEST_FACTORY(
614+
HelperFeatureEnum::
615+
DplExtrasDpcppExtensions_segmented_reduce,
616+
CONDITIONAL_FACTORY_ENTRY(
617+
checkArgCanMappingToSyclNativeBinaryOp(7),
612618
CALL_FACTORY_ENTRY(
613619
"cub::DeviceSegmentedReduce::Reduce",
614620
CALL(TEMPLATED_CALLEE_WITH_ARGS(
615621
MapNames::getDpctNamespace() +
616622
"device::segmented_reduce",
617623
LITERAL("128")),
618624
STREAM(9), ARG(2), ARG(3), ARG(4), ARG(5),
619-
ARG(6), LITERAL("dpct_placeholder"),
620-
ARG(8))),
621-
Diagnostics::UNSUPPORTED_BINARY_OPERATION)),
622-
CONDITIONAL_FACTORY_ENTRY(
623-
checkArgCanMappingToSyclNativeBinaryOp(7),
625+
ARG(6), ARG(7), ARG(8))),
626+
WARNING_FACTORY_ENTRY(
627+
"cub::DeviceSegmentedReduce::Reduce",
628+
CALL_FACTORY_ENTRY(
629+
"cub::DeviceSegmentedReduce::Reduce",
630+
CALL(TEMPLATED_CALLEE_WITH_ARGS(
631+
MapNames::getDpctNamespace() +
632+
"device::segmented_reduce",
633+
LITERAL("128")),
634+
STREAM(9), ARG(2), ARG(3), ARG(4),
635+
ARG(5), ARG(6),
636+
LITERAL("dpct_placeholder"), ARG(8))),
637+
Diagnostics::UNSUPPORTED_BINARY_OPERATION)))),
638+
CONDITIONAL_FACTORY_ENTRY(
639+
checkEnableUserDefineReductions(),
640+
FEATURE_REQUEST_FACTORY(
641+
HelperFeatureEnum::
642+
DplExtrasDpcppExtensions_segmented_reduce_ext,
624643
CALL_FACTORY_ENTRY(
625644
"cub::DeviceSegmentedReduce::Reduce",
626645
CALL(TEMPLATED_CALLEE_WITH_ARGS(
627646
MapNames::getDpctNamespace() +
628-
"device::segmented_reduce",
647+
"device::experimental::segmented_"
648+
"reduce",
629649
LITERAL("128")),
630650
QUEUESTR, ARG(2), ARG(3), ARG(4), ARG(5),
631-
ARG(6), ARG(7), ARG(8))),
632-
WARNING_FACTORY_ENTRY(
633-
"cub::DeviceSegmentedReduce::Reduce",
651+
ARG(6), ARG(7), ARG(8)))),
652+
FEATURE_REQUEST_FACTORY(
653+
HelperFeatureEnum::
654+
DplExtrasDpcppExtensions_segmented_reduce,
655+
CONDITIONAL_FACTORY_ENTRY(
656+
checkArgCanMappingToSyclNativeBinaryOp(7),
634657
CALL_FACTORY_ENTRY(
635658
"cub::DeviceSegmentedReduce::Reduce",
636659
CALL(TEMPLATED_CALLEE_WITH_ARGS(
637660
MapNames::getDpctNamespace() +
638661
"device::segmented_reduce",
639662
LITERAL("128")),
640663
QUEUESTR, ARG(2), ARG(3), ARG(4), ARG(5),
641-
ARG(6), LITERAL("dpct_placeholder"),
642-
ARG(8))),
643-
Diagnostics::UNSUPPORTED_BINARY_OPERATION))),
644-
Diagnostics::REDUCE_PERFORMANCE_TUNE)))))
664+
ARG(6), ARG(7), ARG(8))),
665+
WARNING_FACTORY_ENTRY(
666+
"cub::DeviceSegmentedReduce::Reduce",
667+
CALL_FACTORY_ENTRY(
668+
"cub::DeviceSegmentedReduce::Reduce",
669+
CALL(TEMPLATED_CALLEE_WITH_ARGS(
670+
MapNames::getDpctNamespace() +
671+
"device::segmented_reduce",
672+
LITERAL("128")),
673+
QUEUESTR, ARG(2), ARG(3), ARG(4),
674+
ARG(5), ARG(6),
675+
LITERAL("dpct_placeholder"), ARG(8))),
676+
Diagnostics::UNSUPPORTED_BINARY_OPERATION))))),
677+
Diagnostics::REDUCE_PERFORMANCE_TUNE))))
645678

646679
// cub::DeviceSegmentedReduce::Sum
647680
CONDITIONAL_FACTORY_ENTRY(

clang/lib/DPCT/AnalysisInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,6 +1860,9 @@ class DpctGlobalInfo {
18601860
static bool useLogicalGroup() {
18611861
return getUsingExperimental<ExperimentalFeatures::Exp_LogicalGroup>();
18621862
}
1863+
static bool useUserDefineReductions() {
1864+
return getUsingExperimental<ExperimentalFeatures::Exp_UserDefineReductions>();
1865+
}
18631866
static bool useEnqueueBarrier() {
18641867
return getUsingExtensionDE(DPCPPExtensionsDefaultEnabled::ExtDE_EnqueueBarrier);
18651868
}

clang/lib/DPCT/CUBAPIMigration.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,6 @@ void CubTypeRule::runRule(
9898
}
9999
}
100100

101-
/// Remove this function when the support for user-define operator in
102-
/// reduce_over_group() is available
103101
bool CubTypeRule::CanMappingToSyclNativeBinaryOp(StringRef OpTypeName) {
104102
return OpTypeName == "cub::Sum" || OpTypeName == "cub::Max" ||
105103
OpTypeName == "cub::Min";

clang/lib/DPCT/CallExprRewriterCUB.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ RemoveCubTempStorageFactory::create(const CallExpr *C) const {
2626
return Inner->create(C);
2727
}
2828

29+
std::function<bool(const CallExpr *)>
30+
checkEnableUserDefineReductions() {
31+
return [=](const CallExpr *) -> bool {
32+
return DpctGlobalInfo::useUserDefineReductions();
33+
};
34+
}
35+
2936
std::function<bool(const CallExpr *)>
3037
checkArgCanMappingToSyclNativeBinaryOp(size_t ArgIdx) {
3138
return [=](const CallExpr *C) -> bool {

clang/lib/DPCT/ValidateArguments.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ enum class ExperimentalFeatures : unsigned int {
7272
// this_group, this_subgroup.
7373
Exp_GroupSharedMemory,
7474
Exp_LogicalGroup,
75-
Exp_ExperimentalFeaturesEnumSize
75+
Exp_ExperimentalFeaturesEnumSize,
76+
Exp_UserDefineReductions
7677
};
7778

7879
bool makeInRootCanonicalOrSetDefaults(

clang/runtime/dpct-rt/include/dpl_extras/dpcpp_extensions.h.inc

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
// DPCT_CODE
1919
#include <sycl/sycl.hpp>
2020
#include <stdexcept>
21+
22+
#ifdef SYCL_EXT_ONEAPI_USER_DEFINED_REDUCTIONS
23+
#include <sycl/ext/oneapi/experimental/user_defined_reductions.hpp>
24+
#endif
2125
// DPCT_LABEL_END
2226

2327
namespace dpct {
@@ -709,6 +713,97 @@ void segmented_reduce(sycl::queue queue, T *inputs, T *outputs,
709713
});
710714
});
711715
}
716+
717+
// DPCT_LABEL_END
718+
719+
// DPCT_LABEL_BEGIN|segmented_reduce_ext|dpct::device::experimental
720+
// DPCT_DEPENDENCY_BEGIN
721+
// DplExtrasDpcppExtensions|segmented_reduce
722+
// DPCT_DEPENDENCY_END
723+
// DPCT_CODE
724+
#ifdef SYCL_EXT_ONEAPI_USER_DEFINED_REDUCTIONS
725+
726+
namespace experimental {
727+
namespace detail {
728+
template <typename _Tp, typename... _Ts> struct __is_any {
729+
constexpr static bool value = std::disjunction_v<
730+
std::is_same<std::remove_cv_t<_Tp>, std::remove_cv_t<_Ts>>...>;
731+
};
732+
733+
template <typename _Tp, typename _Bp> struct __in_native_op_list {
734+
constexpr static bool value =
735+
__is_any<_Bp, sycl::plus<_Tp>, sycl::bit_or<_Tp>, sycl::bit_xor<_Tp>,
736+
sycl::bit_and<_Tp>, sycl::maximum<_Tp>, sycl::minimum<_Tp>,
737+
sycl::multiplies<_Tp>>::value;
738+
};
739+
740+
template <typename _Tp, typename _Bp> struct __is_native_op {
741+
constexpr static bool value = __in_native_op_list<_Tp, _Bp>::value ||
742+
__in_native_op_list<void, _Bp>::value;
743+
};
744+
745+
} // namespace detail
746+
747+
/// Perform a reduce on each of the segments specified within data stored on
748+
/// the device. Compared with dpct::device::segmented_reduce, this experimental
749+
/// feature support user define reductions.
750+
///
751+
/// \param queue Command queue used to access device used for reduction
752+
/// \param inputs Pointer to the data elements on the device to be reduced
753+
/// \param outputs Pointer to the storage where the reduced value for each
754+
/// segment will be stored \param segment_count number of segments to be reduced
755+
/// \param begin_offsets Pointer to the set of indices that are the first
756+
/// element in each segment \param end_offsets Pointer to the set of indices
757+
/// that are one past the last element in each segment \param binary_op functor
758+
/// that implements the binary operation used to perform the scan. \param init
759+
/// initial value of the reduction for each segment.
760+
template <int GROUP_SIZE, typename T, typename OffsetT, class BinaryOperation>
761+
void segmented_reduce(sycl::queue queue, T *inputs, T *outputs,
762+
size_t segment_count, OffsetT *begin_offsets,
763+
OffsetT *end_offsets, BinaryOperation binary_op, T init) {
764+
765+
sycl::range<1> global_size(segment_count * GROUP_SIZE);
766+
sycl::range<1> local_size(GROUP_SIZE);
767+
768+
if constexpr (!detail::__is_native_op<T, BinaryOperation>::value) {
769+
queue.submit([&](sycl::handler &cgh) {
770+
size_t temp_memory_size = GROUP_SIZE * sizeof(T);
771+
auto scratch = sycl::local_accessor<std::byte, 1>(temp_memory_size, cgh);
772+
cgh.parallel_for(
773+
sycl::nd_range<1>(global_size, local_size),
774+
[=](sycl::nd_item<1> item) {
775+
OffsetT segment_begin = begin_offsets[item.get_group_linear_id()];
776+
OffsetT segment_end = end_offsets[item.get_group_linear_id()];
777+
if (segment_begin == segment_end) {
778+
if (item.get_local_linear_id() == 0) {
779+
outputs[item.get_group_linear_id()] = init;
780+
}
781+
return;
782+
}
783+
// Create a handle that associates the group with an allocation it
784+
// can use
785+
auto handle =
786+
sycl::ext::oneapi::experimental::group_with_scratchpad(
787+
item.get_group(),
788+
sycl::span(&scratch[0], temp_memory_size));
789+
T group_aggregate = sycl::ext::oneapi::experimental::joint_reduce(
790+
handle, inputs + segment_begin, inputs + segment_end, init,
791+
binary_op);
792+
if (item.get_local_linear_id() == 0) {
793+
outputs[item.get_group_linear_id()] = group_aggregate;
794+
}
795+
});
796+
});
797+
} else {
798+
dpct::device::segmented_reduce<GROUP_SIZE>(queue, inputs, outputs,
799+
segment_count, begin_offsets,
800+
end_offsets, binary_op, init);
801+
}
802+
}
803+
} // namespace experimental
804+
805+
#endif // SYCL_EXT_ONEAPI_USER_DEFINED_REDUCTIONS
806+
712807
// DPCT_LABEL_END
713808

714809
} // namespace device

0 commit comments

Comments
 (0)