Skip to content

Commit 9cd7344

Browse files
author
Lu Teng
authored
Refine GPU runtime code (#426)
1 parent 13a0da0 commit 9cd7344

File tree

2 files changed

+31
-91
lines changed

2 files changed

+31
-91
lines changed

xla/stream_executor/sycl/sycl_gpu_runtime.cc

Lines changed: 27 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,6 @@ limitations under the License.
2828

2929
namespace {
3030

31-
// SYCL_TILE_AS_DEVICE
32-
// True (default behaviour): Tile as an individual device in device list
33-
// False: Only root device as an individual device in device list
34-
inline bool TileAsDevice() {
35-
bool tile_as_device;
36-
TF_CHECK_OK(
37-
tsl::ReadBoolFromEnvVar("SYCL_TILE_AS_DEVICE", true, &tile_as_device));
38-
return tile_as_device;
39-
}
40-
41-
inline bool RunOnLevelZero() {
42-
char* sycl_device_filter = getenv("SYCL_DEVICE_FILTER");
43-
// Current default backend platform is Level-Zero
44-
if (sycl_device_filter == nullptr) return true;
45-
auto filter_device = std::string(sycl_device_filter);
46-
std::transform(filter_device.begin(), filter_device.end(),
47-
filter_device.begin(),
48-
[](unsigned char c) { return std::tolower(c); });
49-
return filter_device.find("level_zero") != std::string::npos;
50-
}
51-
5231
bool hasDevice() {
5332
int count = 0;
5433
SYCLError_t error = SYCLGetDeviceCount(&count);
@@ -125,13 +104,9 @@ class DevicePool {
125104
auto platform_list = sycl::platform::get_platforms();
126105
for (const auto& platform : platform_list) {
127106
auto platform_name = platform.get_info<sycl::info::platform::name>();
128-
bool is_level_zero =
129-
platform_name.find("Level-Zero") != std::string::npos;
130-
// Add device in these two scenarios:
131-
// true == true means need Level-Zero and the backend platform is
132-
// Level-Zero.
133-
// false == false mean need OCL and the backend platform is OCL.
134-
if (is_level_zero == RunOnLevelZero()) {
107+
bool is_found = platform_name.find("Level-Zero") != std::string::npos;
108+
109+
if (is_found) {
135110
LOG(INFO) << "Selected platform: " << platform_name;
136111
auto device_list = platform.get_devices();
137112
for (const auto& device : device_list) {
@@ -142,37 +117,7 @@ class DevicePool {
142117
}
143118
}
144119

145-
if (TileAsDevice()) {
146-
// If SYCL_TILE_AS_DEVICE is true.
147-
// Create sub devices from root devices:
148-
// If succ, add sub devices into devices list
149-
// If fail, add root devices into devices list
150-
constexpr auto partition_by_affinity =
151-
sycl::info::partition_property::partition_by_affinity_domain;
152-
constexpr auto next_partitionable =
153-
sycl::info::partition_affinity_domain::next_partitionable;
154-
for (const auto& root_device : root_devices) {
155-
std::vector<sycl::device> sub_devices;
156-
auto max_sub_devices =
157-
root_device
158-
.get_info<sycl::info::device::partition_max_sub_devices>();
159-
if (max_sub_devices == 0) {
160-
LOG(INFO) << "number of sub-devices is zero, expose root "
161-
"device.";
162-
devices.push_back(root_device);
163-
} else {
164-
sub_devices = root_device.create_sub_devices<partition_by_affinity>(
165-
next_partitionable);
166-
devices.insert(devices.end(), sub_devices.begin(),
167-
sub_devices.end());
168-
}
169-
}
170-
} else {
171-
// If SYCL_TILE_AS_DEVICE is false.
172-
// Only set root device as device list.
173-
devices = std::move(root_devices);
174-
}
175-
120+
devices = std::move(root_devices);
176121
size_t num_device = devices.size();
177122

178123
if (num_device <= 0) {
@@ -226,27 +171,27 @@ static sycl::async_handler SYCLAsyncHandler = [](sycl::exception_list eL) {
226171
try {
227172
std::rethrow_exception(e);
228173
} catch (sycl::exception& e) {
229-
LOG(ERROR) << "DPC++ Exception: " << e.what() << ", file = " << __FILE__
174+
LOG(ERROR) << "SYCL Exception: " << e.what() << ", file = " << __FILE__
230175
<< ", line = " << __LINE__ << ".";
231176
}
232177
}
233178
};
234179

235180
SYCLError_t SYCLStreamPool::getDefaultStream(sycl::device* device_handle,
236-
sycl::queue** stream_p) {
237-
*stream_p = SYCLStreamPool::GetStreamsPool(device_handle)[0].get();
238-
return SYCL_SUCCESS;
181+
sycl::queue** stream_p) {
182+
*stream_p = SYCLStreamPool::GetStreamsPool(device_handle)[0].get();
183+
return SYCL_SUCCESS;
239184
}
240185

241186
SYCLError_t SYCLStreamPool::createStream(sycl::device* device_handle,
242-
sycl::queue** stream_p) {
187+
sycl::queue** stream_p) {
243188
if (IsMultipleStreamEnabled()) {
244189
sycl::property_list propList{sycl::property::queue::enable_profiling(),
245-
sycl::property::queue::in_order()};
190+
sycl::property::queue::in_order()};
246191
SYCLStreamPool::GetStreamsPool(device_handle)
247-
.push_back(std::make_shared<sycl::queue>(
248-
DevicePool::getDeviceContext(), *device_handle, SYCLAsyncHandler,
249-
propList));
192+
.push_back(std::make_shared<sycl::queue>(DevicePool::getDeviceContext(),
193+
*device_handle,
194+
SYCLAsyncHandler, propList));
250195
}
251196
*stream_p = SYCLStreamPool::GetStreamsPool(device_handle).back().get();
252197
return SYCL_SUCCESS;
@@ -259,8 +204,8 @@ SYCLError_t SYCLStreamPool::syncContext(sycl::device* device_handle) {
259204
return SYCL_SUCCESS;
260205
}
261206

262-
SYCLError_t SYCLStreamPool::destroyStream(sycl::device* device_handle,
263-
sycl::queue* stream_handle) {
207+
SYCLError_t SYCLStreamPool::destroyStream(sycl::device* device_handle,
208+
sycl::queue* stream_handle) {
264209
if (stream_handle == nullptr) return SYCL_ERROR_INVALID_STREAM;
265210
auto stream_pool = SYCLStreamPool::GetStreamsPool(device_handle);
266211
for (int i = 0; i < stream_pool.size(); i++) {
@@ -280,7 +225,7 @@ std::vector<std::shared_ptr<sycl::queue>>& SYCLStreamPool::GetStreamsPool(
280225
auto iter = stream_pool_map.find(device_handle);
281226
if (iter != stream_pool_map.end()) return iter->second;
282227
sycl::property_list propList{sycl::property::queue::enable_profiling(),
283-
sycl::property::queue::in_order()};
228+
sycl::property::queue::in_order()};
284229
std::vector<std::shared_ptr<sycl::queue>> stream_pool = {
285230
std::make_shared<sycl::queue>(DevicePool::getDeviceContext(),
286231
*device_handle, SYCLAsyncHandler,
@@ -471,17 +416,17 @@ SYCLError_t SYCLMemsetD32Async(void* dstDevice, unsigned int ui, size_t N,
471416
}
472417

473418
SYCLError_t SYCLMemcpyAsync(void* dst, const void* src, size_t ByteCount,
474-
SYCLError_t (*func)(void*, const void*, size_t, sycl::queue*),
475-
sycl::queue* stream){
419+
SYCLError_t (*func)(void*, const void*, size_t,
420+
sycl::queue*),
421+
sycl::queue* stream) {
476422
return (*func)(dst, src, ByteCount, stream);
477423
}
478424

479-
SYCLError_t SYCLStreamSynchronize(sycl::queue* stream){
425+
SYCLError_t SYCLStreamSynchronize(sycl::queue* stream) {
480426
stream->wait();
481427
return SYCL_SUCCESS;
482428
}
483429

484-
485430
void* SYCLMalloc(sycl::device* device, size_t ByteCount) {
486431
sycl::queue* stream;
487432
SYCLStreamPool::getDefaultStream(device, &stream);
@@ -537,18 +482,18 @@ void SYCLStreamDependOnEvents(sycl::queue* stream,
537482
const char* ToString(SYCLError_t error) {
538483
switch (error) {
539484
case SYCL_SUCCESS:
540-
return "DPC++ succeed.";
485+
return "SYCL succeed.";
541486
case SYCL_ERROR_NO_DEVICE:
542-
return "DPC++ did not find the device.";
487+
return "SYCL did not find the device.";
543488
case SYCL_ERROR_INVALID_DEVICE:
544-
return "DPC++ got invalid device id.";
489+
return "SYCL got invalid device id.";
545490
case SYCL_ERROR_INVALID_POINTER:
546-
return "DPC++ got invalid pointer.";
491+
return "SYCL got invalid pointer.";
547492
case SYCL_ERROR_INVALID_STREAM:
548-
return "DPC++ got invalid stream.";
493+
return "SYCL got invalid stream.";
549494
case SYCL_ERROR_DESTROY_DEFAULT_STREAM:
550-
return "DPC++ cannot destroy default stream.";
495+
return "SYCL cannot destroy default stream.";
551496
default:
552-
return "DPC++ got invalid error code.";
497+
return "SYCL got invalid error code.";
553498
}
554499
} // namespace

xla/stream_executor/sycl/sycl_gpu_runtime.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,8 @@ limitations under the License.
2121

2222
#include "absl/strings/ascii.h"
2323

24-
#if __has_include(<sycl/sycl.hpp>)
25-
#include <sycl/sycl.hpp>
26-
#elif __has_include(<CL/sycl.hpp>)
27-
#include <CL/sycl.hpp>
28-
#else
29-
#error "Unsupported compiler"
30-
#endif
31-
3224
#include <level_zero/ze_api.h>
25+
#include <sycl/sycl.hpp>
3326

3427
enum SYCLError_t {
3528
SYCL_SUCCESS,
@@ -100,7 +93,8 @@ void* SYCLMallocShared(sycl::device* device, size_t ByteCount);
10093
void SYCLFree(sycl::device* device, void* ptr);
10194

10295
SYCLError_t SYCLMemcpyAsync(void* dst, const void* src, size_t ByteCount,
103-
SYCLError_t (*func)(void*, const void*, size_t, sycl::queue*),
96+
SYCLError_t (*func)(void*, const void*, size_t,
97+
sycl::queue*),
10498
sycl::queue* stream);
10599

106100
SYCLError_t SYCLStreamSynchronize(sycl::queue* stream);
@@ -121,6 +115,7 @@ class SYCLStreamPool {
121115
static SYCLError_t syncContext(sycl::device* device_handle);
122116
static SYCLError_t destroyStream(sycl::device* device_handle,
123117
sycl::queue* stream_handle);
118+
124119
private:
125120
static std::vector<std::shared_ptr<sycl::queue>>& GetStreamsPool(
126121
sycl::device* device_handle);

0 commit comments

Comments
 (0)