Skip to content

Commit ded8a1e

Browse files
authored
[CINN] [New Hardware Update]:SYCL PR6: fix sycl bugs and sycl reuse paddle stream (#74328)
* fix paddle-dcu build link error * fix sycl bugs * fix sycl codegen * fix sycl minmax bug * sycl reuse paddle stream * fix bugs * fix codegen style
1 parent 1af9eb1 commit ded8a1e

File tree

12 files changed

+92
-68
lines changed

12 files changed

+92
-68
lines changed

paddle/cinn/backends/sycl/codegen_sycl_dev.cc

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -100,38 +100,15 @@ void CodeGenSyclDevice::Visit(const ir::_LoweredFunc_ *op) {
100100
DoIndent();
101101
str_ += "h.parallel_for<class " + GenerateKernelName(op) +
102102
">(sycl::nd_range<3>(dimGrid * dimBlock, dimBlock), "
103-
"[=](sycl::nd_item<3> item) "
104-
"[[intel::kernel_args_restrict]]";
105-
if (op->cuda_axis_info.valid()) {
106-
bool has_symbol_in_thread_num = false;
107-
std::string launch_bounds_max_work_group_size =
108-
"[[intel::max_work_group_size(";
109-
for (int i = 0; i < 3; i++) {
110-
ir::Expr block_dim = op->cuda_axis_info.block_dim(i);
111-
if (block_dim.is_constant()) {
112-
launch_bounds_max_work_group_size +=
113-
std::to_string(block_dim.as_int64());
114-
if (i < 2) {
115-
launch_bounds_max_work_group_size += ", ";
116-
}
117-
} else {
118-
has_symbol_in_thread_num = true;
119-
break;
120-
}
121-
}
122-
launch_bounds_max_work_group_size += ")]]";
123-
if (!has_symbol_in_thread_num) {
124-
str_ += launch_bounds_max_work_group_size;
125-
}
126-
}
103+
"[=](sycl::nd_item<3> item) ";
127104
str_ += "\n";
128105

129106
PrintFunctionBody(op);
130107

131108
str_ += ");\n";
132109
DecIndent();
133110
DoIndent();
134-
str_ += "});\n";
111+
str_ += "}).wait();\n";
135112
DecIndent();
136113
str_ += "}\n";
137114
}
@@ -230,9 +207,9 @@ void CodeGenSyclDevice::PrintFunctionDeclaration(const ir::_LoweredFunc_ *op) {
230207
} else {
231208
CINN_NOT_IMPLEMENTED
232209
}
233-
str_ += ")(*(void **)(void_args[";
210+
str_ += ")(void_args[";
234211
str_ += std::to_string(i);
235-
str_ += "]));\n";
212+
str_ += "]);\n";
236213
}
237214
}
238215

paddle/cinn/backends/sycl/compiler_sycl.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ class Compiler {
5050
std::string compiler_path = SYCL_CXX_COMPILER;
5151
std::string prefix_dir = "./source";
5252
std::string cxx_compile_options =
53-
"-std=c++17 -fPIC -shared -ldl -fbracket-depth=1030"; // set 1030 for
54-
// constant op,
55-
// default max
56-
// bracket-depth
57-
// = 256";
53+
"-std=c++17 -fPIC -shared -w -ldl -fbracket-depth=1030"; // set 1030 for
54+
// constant op,
55+
// default max
56+
// bracket-depth
57+
// = 256";
5858
std::string device_arch_options;
5959
int compile_num = 0;
6060
std::string source_file_path;

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
14831483
pir::RewritePatternSet ps(context);
14841484
ps.Add<ScaleOpPattern>(
14851485
context); // NOTE, scale op pattern should before AddBroadcastTo
1486+
#ifndef CINN_WITH_SYCL
14861487
ps.Add<SumOpPattern>(context);
14871488
ps.Add<ReduceMinMaxOpPattern<paddle::dialect::MinOp,
14881489
cinn::dialect::ReduceMinOp>>(context);
@@ -1495,6 +1496,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
14951496
ArgMinMaxOpPattern<paddle::dialect::ArgmaxOp, cinn::dialect::ArgmaxOp>>(
14961497
context);
14971498
// Arange in this pass only handles static inputs
1499+
#endif
14981500
ps.Add<ArangeOpPattern>(context);
14991501
ps.Add<ProdOpPattern>(context);
15001502
ps.Add<ReshapeOpPattern>(context);

paddle/cinn/runtime/sycl/cinn_sycl_runtime_source.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
extern "C" {
2828

29-
#define MAX_SUBGROUP_SIZE 32
29+
#define MAX_SUBGROUP_SIZE 64
3030
#define MAX_THREADNUM_INGROUP 1024
3131
#define MAX_SUBGROUPNUM_INGROUP \
3232
((MAX_THREADNUM_INGROUP - 1) / MAX_SUBGROUP_SIZE + 1)

paddle/cinn/runtime/sycl/sycl_backend_api.cc

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#include "paddle/cinn/runtime/sycl/sycl_backend_api.h"
1616
#include <glog/logging.h>
17+
#include <hip/hip_runtime.h>
18+
#include <sycl/ext/oneapi/experimental/backend/hip.hpp>
1719

1820
namespace cinn {
1921
namespace runtime {
@@ -89,9 +91,6 @@ void SYCLBackendAPI::set_device(int device_id) {
8991
// create context and queue
9092
this->contexts[device_id] =
9193
new ::sycl::context(this->devices[device_id], exception_handler);
92-
// one device one queue
93-
this->queues[device_id].push_back(new ::sycl::queue(
94-
*this->contexts[device_id], this->devices[device_id], q_prop));
9594
}
9695
this->now_device_id = device_id;
9796
}
@@ -105,36 +104,36 @@ int SYCLBackendAPI::get_device_property(DeviceProperty device_property,
105104

106105
switch (device_property) {
107106
case DeviceProperty::MaxBlockDimX: {
108-
::sycl::_V1::id<3> max_work_item_sizes =
107+
::sycl::id<3> max_work_item_sizes =
109108
this->devices[index]
110-
.get_info<::sycl::_V1::info::device::max_work_item_sizes<3>>();
109+
.get_info<::sycl::info::device::max_work_item_sizes>();
111110
rv = max_work_item_sizes[0];
112111
break;
113112
}
114113
case DeviceProperty::MaxBlockDimY: {
115-
::sycl::_V1::id<3> max_work_item_sizes =
114+
::sycl::id<3> max_work_item_sizes =
116115
this->devices[index]
117-
.get_info<::sycl::_V1::info::device::max_work_item_sizes<3>>();
116+
.get_info<::sycl::info::device::max_work_item_sizes>();
118117
rv = max_work_item_sizes[1];
119118
break;
120119
}
121120
case DeviceProperty::MaxBlockDimZ: {
122-
::sycl::_V1::id<3> max_work_item_sizes =
121+
::sycl::id<3> max_work_item_sizes =
123122
this->devices[index]
124-
.get_info<::sycl::_V1::info::device::max_work_item_sizes<3>>();
123+
.get_info<::sycl::info::device::max_work_item_sizes>();
125124
rv = max_work_item_sizes[2];
126125
break;
127126
}
128127
case DeviceProperty::MaxGridDimX: {
129-
rv = 2097151;
128+
rv = 2147483647;
130129
break;
131130
}
132131
case DeviceProperty::MaxGridDimY: {
133-
rv = 2097151;
132+
rv = 2147483647;
134133
break;
135134
}
136135
case DeviceProperty::MaxGridDimZ: {
137-
rv = 2097151;
136+
rv = 2147483647;
138137
break;
139138
}
140139
case DeviceProperty::MaxSharedMemoryPerBlock: {
@@ -239,7 +238,27 @@ void SYCLBackendAPI::stream_sync(void* stream) {
239238
SYCL_CALL(static_cast<::sycl::queue*>(stream)->wait_and_throw());
240239
}
241240

242-
::sycl::queue* SYCLBackendAPI::get_now_queue() {
241+
::sycl::queue* SYCLBackendAPI::get_now_queue(void* raw_stream) {
242+
if (this->queues[now_device_id].size() == 0) {
243+
int current_device_id;
244+
hipGetDevice(&current_device_id);
245+
hipSetDevice(current_device_id);
246+
hipDeviceGet(&device_, current_device_id);
247+
hipCtxGetCurrent(&context_);
248+
hipDevicePrimaryCtxRetain(&context_, device_);
249+
250+
::sycl::backend_input_t<::sycl::backend::ext_oneapi_hip, ::sycl::context>
251+
InteropContextInput{context_};
252+
::sycl::context InteropContext =
253+
::sycl::make_context<::sycl::backend::ext_oneapi_hip>(
254+
InteropContextInput);
255+
256+
hipStream_t hipStream = static_cast<hipStream_t>(raw_stream);
257+
auto Q =
258+
new ::sycl::queue(::sycl::make_queue<::sycl::backend::ext_oneapi_hip>(
259+
hipStream, InteropContext));
260+
this->queues[now_device_id].push_back(Q);
261+
}
243262
return this->queues[now_device_id][0];
244263
}
245264

@@ -276,9 +295,9 @@ std::array<int, 3> SYCLBackendAPI::get_max_block_dims(
276295
std::optional<int> device_id) {
277296
std::array<int, 3> kMaxBlockDims;
278297
int index = device_id.value_or(this->now_device_id);
279-
::sycl::_V1::id<3> max_work_item_sizes =
298+
::sycl::id<3> max_work_item_sizes =
280299
this->devices[index]
281-
.get_info<::sycl::_V1::info::device::max_work_item_sizes<3>>();
300+
.get_info<::sycl::info::device::max_work_item_sizes>();
282301
kMaxBlockDims = std::array<int, 3>{
283302
max_work_item_sizes[2], max_work_item_sizes[1], max_work_item_sizes[0]};
284303
return kMaxBlockDims;
@@ -287,7 +306,7 @@ std::array<int, 3> SYCLBackendAPI::get_max_block_dims(
287306
std::array<int, 3> SYCLBackendAPI::get_max_grid_dims(
288307
std::optional<int> device_id) {
289308
std::array<int, 3> kMaxGridDims;
290-
kMaxGridDims = std::array<int, 3>{2097151, 2097151, 2097151};
309+
kMaxGridDims = std::array<int, 3>{2147483647, 2147483647, 2147483647};
291310
return kMaxGridDims;
292311
}
293312

paddle/cinn/runtime/sycl/sycl_backend_api.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#pragma once
1616

1717
#include <sycl/sycl.hpp>
18+
19+
#include <hip/hip_runtime.h>
1820
#include <vector>
1921
#include "paddle/cinn/common/macros.h"
2022
#include "paddle/cinn/common/target.h"
@@ -103,7 +105,7 @@ class SYCLBackendAPI final : public BackendAPI {
103105
MemcpyType type) final;
104106
void device_sync() final;
105107
void stream_sync(void* stream) final;
106-
::sycl::queue* get_now_queue();
108+
::sycl::queue* get_now_queue(void* stream);
107109
std::string GetGpuVersion();
108110
std::array<int, 3> get_max_grid_dims(
109111
std::optional<int> device_id = std::nullopt) final;
@@ -118,9 +120,11 @@ class SYCLBackendAPI final : public BackendAPI {
118120
// all queues in all devices
119121
std::vector<std::vector<::sycl::queue*>> queues;
120122
// now_device_id, change by set_device()
121-
int now_device_id = -1;
123+
int now_device_id = 0;
122124
// whether the BackendAPI is initialized.
123125
bool initialized_{false};
126+
hipDevice_t device_;
127+
hipCtx_t context_;
124128
};
125129
} // namespace sycl
126130
} // namespace runtime

paddle/cinn/runtime/sycl/sycl_intrinsics.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,8 @@ CINN_REGISTER_HELPER(cinn_sycl_host_api) {
445445
.AddInputType<int>() // block_x
446446
.AddInputType<int>() // block_y
447447
.AddInputType<int>() // block_z
448+
.AddInputType<int>() // shared_memory_bytes
449+
.AddInputType<void *>() // stream
448450
.End();
449451
using cinn::runtime::sycl::infer_shape_set_value;
450452

@@ -465,6 +467,7 @@ CINN_REGISTER_HELPER(cinn_sycl_host_api) {
465467
.AddInputType<int>() // num_args
466468
.AddInputType<int>() // value
467469
.AddInputType<size_t>() // count
470+
.AddInputType<void *>() // stream
468471
.End();
469472

470473
using cinn::runtime::sycl::cinn_call_sycl_memcpy;
@@ -474,6 +477,7 @@ CINN_REGISTER_HELPER(cinn_sycl_host_api) {
474477
.AddInputType<void *>() // v_args
475478
.AddInputType<int>() // num_args
476479
.AddInputType<size_t>() // count
480+
.AddInputType<void *>() // stream
477481
.End();
478482

479483
#ifdef CINN_WITH_CNNL

paddle/cinn/runtime/sycl/sycl_module.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <glog/logging.h>
1717
#include <glog/raw_logging.h>
1818

19+
#include <hip/hip_runtime.h>
1920
#include "paddle/cinn/runtime/cinn_runtime.h"
2021
#include "paddle/cinn/runtime/sycl/sycl_backend_api.h"
2122
#include "paddle/cinn/runtime/sycl/sycl_module.h"
@@ -38,6 +39,7 @@ SYCLModule::SYCLModule(const std::string& source_code,
3839
SYCLModule::~SYCLModule() { VLOG(3) << "destructor for SYCLModule"; }
3940

4041
void* SYCLModule::GetFunction(const std::string& func_name) {
42+
std::lock_guard<std::mutex> lock(mutex_);
4143
if (so_handler_ == nullptr) {
4244
so_handler_ = dlopen(shared_library_.c_str(), RTLD_NOW | RTLD_GLOBAL);
4345
}

paddle/cinn/runtime/sycl/sycl_util.cc

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222
#include "paddle/cinn/runtime/sycl/sycl_util.h"
2323
#include "paddle/cinn/utils/profiler.h"
2424
#include "paddle/common/enforce.h"
25-
#ifdef CINN_WITH_CNNL
26-
#include <cn_api.h>
27-
#include <cnnl.h>
28-
#include <CL/sycl/backend/cnrt.hpp>
25+
#ifdef CINN_WITH_HIP
26+
#include <hip/hip_runtime.h>
2927
#endif
3028

3129
namespace cinn {
@@ -40,7 +38,9 @@ void cinn_call_sycl_kernel(void *kernel_fn,
4038
int grid_z,
4139
int block_x,
4240
int block_y,
43-
int block_z) {
41+
int block_z,
42+
int shared_memory_bytes,
43+
void *stream) {
4444
VLOG(3) << "cinn_call_sycl_kernel, grid_dim={" << grid_x << ", " << grid_y
4545
<< ", " << grid_z << "}, block_dim={" << block_x << ", " << block_y
4646
<< ", " << block_z << "}, num_args=" << num_args;
@@ -58,7 +58,7 @@ void cinn_call_sycl_kernel(void *kernel_fn,
5858
ss << std::hex << addr;
5959
VLOG(4) << "sycl kernel arg[" << idx
6060
<< "] is a buffer, addr=" << ss.str();
61-
kernel_args.emplace_back(&addr);
61+
kernel_args.emplace_back(addr);
6262
} else {
6363
kernel_args.emplace_back((args[idx].data_addr()));
6464
}
@@ -76,7 +76,7 @@ void cinn_call_sycl_kernel(void *kernel_fn,
7676
::sycl::range<3> k0_dimGrid,
7777
::sycl::range<3> k0_dimBlock,
7878
void **void_args))(kernel_fn);
79-
::sycl::queue *Queue = SYCLBackendAPI::Global()->get_now_queue();
79+
::sycl::queue *Queue = SYCLBackendAPI::Global()->get_now_queue(stream);
8080
::sycl::range<3> Grid(grid_z, grid_y, grid_x);
8181
::sycl::range<3> Block(block_z, block_y, block_x);
8282
SYCL_CALL(kernel_func(*Queue, Grid, Block, kernel_args.data()));
@@ -87,10 +87,8 @@ void infer_shape_set_value(int row, int col, int64_t value, int64_t **v) {
8787
v[row][col] = value;
8888
}
8989

90-
void cinn_call_sycl_memset(void *v_args,
91-
int num_args,
92-
int value,
93-
size_t count) {
90+
void cinn_call_sycl_memset(
91+
void *v_args, int num_args, int value, size_t count, void *stream) {
9492
PADDLE_ENFORCE_EQ(num_args,
9593
1,
9694
::common::errors::PreconditionNotMet(
@@ -105,12 +103,15 @@ void cinn_call_sycl_memset(void *v_args,
105103
ss << std::hex << output;
106104
VLOG(4) << "cinn_call_sycl_memset: " << ss.str();
107105

108-
auto Queue = SYCLBackendAPI::Global()->get_now_queue();
106+
::sycl::queue *Queue = SYCLBackendAPI::Global()->get_now_queue(stream);
109107

110108
SYCL_CALL(Queue->memset(output, value, count));
111109
}
112110

113-
void cinn_call_sycl_memcpy(void *v_args, int num_args, size_t count) {
111+
void cinn_call_sycl_memcpy(void *v_args,
112+
int num_args,
113+
size_t count,
114+
void *stream) {
114115
PADDLE_ENFORCE_EQ(
115116
num_args,
116117
2,
@@ -132,7 +133,7 @@ void cinn_call_sycl_memcpy(void *v_args, int num_args, size_t count) {
132133
ss << std::hex << input << " -> " << output;
133134
VLOG(4) << "cinn_call_sycl_memcpy: " << ss.str();
134135

135-
auto Queue = SYCLBackendAPI::Global()->get_now_queue();
136+
::sycl::queue *Queue = SYCLBackendAPI::Global()->get_now_queue(stream);
136137

137138
SYCL_CALL(Queue->memcpy(output, input, count));
138139
}

paddle/cinn/runtime/sycl/sycl_util.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,22 @@ void cinn_call_sycl_kernel(void* kernel_fn,
4343
int grid_z,
4444
int block_x,
4545
int block_y,
46-
int block_z);
46+
int block_z,
47+
int shared_memory_bytes,
48+
void* stream);
4749

4850
void infer_shape_set_value(int row, int col, int64_t value, int64_t** v);
4951

50-
void cinn_call_sycl_memset(void* v_args, int num_args, int value, size_t count);
52+
void cinn_call_sycl_memset(void* v_args,
53+
int num_args,
54+
int value,
55+
size_t count,
56+
void* stream = nullptr);
5157

52-
void cinn_call_sycl_memcpy(void* v_args, int num_args, size_t count);
58+
void cinn_call_sycl_memcpy(void* v_args,
59+
int num_args,
60+
size_t count,
61+
void* stream = nullptr);
5362

5463
#ifdef CINN_WITH_CNNL
5564
#define CNRT_CALL(func) \

0 commit comments

Comments
 (0)