Skip to content

Commit 9df3eef

Browse files
authored
[CINN] support dcu get_value_in_kernel_args with sycl and hip (#74266)
1 parent c917219 commit 9df3eef

File tree

10 files changed: 19 additions and 22 deletions.

paddle/cinn/backends/codegen_device_util.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,12 @@ ir::Module CreateSwitchWithBroadcastConditionModule(
4444
ir::Argument(kernel_args, ir::Argument::IO::kOutput),
4545
ir::Argument(kernel_args_num, ir::Argument::IO::kInput),
4646
ir::Argument(tensor_shape_args, ir::Argument::IO::kOutput)};
47-
4847
const auto &symbolic_arg_define = [&]() -> std::vector<ir::Expr> {
4948
std::vector<ir::Expr> arg_defs;
5049
for (const auto &item : symbolic_shape_var_index) {
5150
ir::Expr call_get_value_in_kernel_args =
5251
ir::Call::Make(Int(64),
53-
runtime::intrinsic::get_value_in_cuda_kernel_args,
52+
runtime::intrinsic::get_value_in_kernel_args,
5453
{kernel_args, ir::Expr(item.first)},
5554
{},
5655
ir::CallType::Extern,
@@ -384,7 +383,7 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
384383
if (args[i].is_var()) {
385384
ir::Expr call_get_value_in_kernel_args =
386385
ir::Call::Make(Int(64),
387-
runtime::intrinsic::get_value_in_cuda_kernel_args,
386+
runtime::intrinsic::get_value_in_kernel_args,
388387
{kernel_args_, ir::Expr(i)},
389388
{},
390389
ir::CallType::Extern,

paddle/cinn/backends/codegen_invoke_module.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class CodeGenSwitchHost : public CodeGenInvokeModule {
6868
: CodeGenInvokeModule(m, b, vars) {}
6969
// only support call of args get function and inner case host function call
7070
llvm::Value *Visit(const ir::Call *op) override {
71-
if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
71+
if (op->name == runtime::intrinsic::get_value_in_kernel_args) {
7272
return CodeGenLLVM::Visit(op);
7373
} else {
7474
return LowerInnerCaseCall(op);

paddle/cinn/runtime/cpu/host_intrinsics.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ inline int64_t FN_INT64(logical_right_shift)(int64_t x, int64_t y) {
280280
}
281281

282282
#undef FN_INT64
283+
284+
int64_t cinn_get_value_in_kernel_args(void* v_args, int idx) {
285+
cinn_pod_value_t* args = static_cast<cinn_pod_value_t*>(v_args);
286+
return args[idx].operator int64_t();
287+
}
288+
283289
} // extern "C"
284290

285291
CINN_REGISTER_HELPER(host_intrinsics) {
@@ -469,5 +475,11 @@ CINN_REGISTER_HELPER(host_intrinsics) {
469475
.AddInputType<int>()
470476
.End();
471477

478+
REGISTER_EXTERN_FUNC_HELPER(cinn_get_value_in_kernel_args, host_target)
479+
.SetRetType<int64_t>()
480+
.AddInputType<void*>() // args
481+
.AddInputType<int>() // index
482+
.End();
483+
472484
return true;
473485
}

paddle/cinn/runtime/cpu/host_intrinsics.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,6 @@ inline float FN_FP32(cbrt)(float x);
121121
inline double FN_FP64(cbrt)(double x);
122122

123123
#undef FN_FP64
124+
125+
int64_t cinn_get_value_in_kernel_args(void* v_args, int idx);
124126
}

paddle/cinn/runtime/cuda/cuda_intrinsics.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -431,14 +431,6 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
431431
}
432432

433433
CINN_REGISTER_HELPER(cinn_cuda_host_api) {
434-
using cinn::runtime::cuda::cinn_get_value_in_cuda_kernel_args;
435-
REGISTER_EXTERN_FUNC_HELPER(cinn_get_value_in_cuda_kernel_args,
436-
cinn::common::DefaultHostTarget())
437-
.SetRetType<int64_t>()
438-
.AddInputType<void *>() // args
439-
.AddInputType<int>() // index
440-
.End();
441-
442434
using cinn::runtime::cuda::cinn_get_item_in_cuda_kernel_args;
443435
REGISTER_EXTERN_FUNC_HELPER(cinn_get_item_in_cuda_kernel_args,
444436
cinn::common::DefaultHostTarget())

paddle/cinn/runtime/cuda/cuda_util.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,6 @@ class CublasHandle {
7878
cublasHandle_t cuhandle;
7979
};
8080

81-
int64_t cinn_get_value_in_cuda_kernel_args(void *v_args, int idx) {
82-
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
83-
return args[idx].operator int64_t();
84-
}
85-
8681
void *cinn_get_item_in_cuda_kernel_args(void *v_args, int idx) {
8782
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
8883
return static_cast<void *>(&args[idx]);

paddle/cinn/runtime/cuda/cuda_util.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ void cinn_call_cuda_memcpy(void* v_args,
8989
size_t count,
9090
void* stream = nullptr);
9191

92-
int64_t cinn_get_value_in_cuda_kernel_args(void* v_args, int idx);
9392
void* cinn_get_item_in_cuda_kernel_args(void* v_args, int idx);
9493

9594
void infer_shape_set_value(int row, int col, int64_t value, int64_t** v);

paddle/cinn/runtime/hip/hip_intrinsics.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ using cinn::backends::GlobalSymbolRegistry;
1717
using cinn::runtime::hip::HIPBackendAPI;
1818
#include "paddle/cinn/backends/extern_func_jit_register.h"
1919
#include "paddle/cinn/runtime/hip/hip_util.h"
20-
using cinn::runtime::hip::cinn_call_hip_kernel;
2120

2221
CINN_REGISTER_HELPER(cinn_hip_host_api) {
2322
GlobalSymbolRegistry::Global().RegisterFn(
2423
"backend_api.hip", reinterpret_cast<void *>(HIPBackendAPI::Global()));
2524

25+
using cinn::runtime::hip::cinn_call_hip_kernel;
2626
REGISTER_EXTERN_FUNC_HELPER(cinn_call_hip_kernel,
2727
cinn::common::DefaultHostTarget())
2828
.SetRetType<void>()

paddle/cinn/runtime/intrinsic.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ static const char* call_sycl_kernel = "cinn_call_sycl_kernel";
113113

114114
static const char* call_cuda_memset = "cinn_call_cuda_memset";
115115

116-
static const char* get_value_in_cuda_kernel_args =
117-
"cinn_get_value_in_cuda_kernel_args";
116+
static const char* get_value_in_kernel_args = "cinn_get_value_in_kernel_args";
118117

119118
static const char* get_item_in_cuda_kernel_args =
120119
"cinn_get_item_in_cuda_kernel_args";

paddle/cinn/runtime/sycl/sycl_util.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ void cinn_call_sycl_memcpy(void *v_args,
139139
}
140140

141141
#ifdef CINN_WITH_CNNL
142-
143142
class CnnlHandle {
144143
public:
145144
CnnlHandle(const CnnlHandle &) = delete;

0 commit comments

Comments (0)