Skip to content

Commit bbf50c4

Browse files
committed
Allow cases for copying device_global with different template arguments
Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 2df6194 commit bbf50c4

File tree

4 files changed

+112
-9
lines changed

4 files changed

+112
-9
lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_device_global.asciidoc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,11 @@ public:
247247
// Available if has_property<device_image_scope> is false
248248
constexpr device_global(const device_global &other);
249249
250+
// Available if has_property<device_image_scope> is false and OtherT is
251+
// convertible to T
252+
template <typename OtherT, typename OtherProps>
253+
constexpr device_global(const device_global<OtherT, OtherProps> &other) {}
254+
250255
device_global(const device_global &&) = delete;
251256
device_global &operator=(const device_global &) = delete;
252257
device_global &operator=(const device_global &&) = delete;
@@ -341,6 +346,22 @@ The storage on each device for `T` is initialized with a copy of the storage in
341346

342347
`T` must be copy constructible and trivially destructible.
343348

349+
// --- ROW BREAK ---
350+
a|
351+
[source,c++]
352+
----
353+
template <typename OtherT, typename OtherProps>
354+
constexpr device_global(const device_global<OtherT, OtherProps> &other) {}
355+
----
356+
|
357+
Available if `has_property<device_image_scope> == false`.
358+
359+
Constructs a `device_global` object, and implicit storage for `T` in the global address space on each device that may access it.
360+
361+
The storage on each device for `T` is initialized with a copy of the storage in `other`.
362+
363+
`OtherT` must be convertible to `T` and `T` must be trivially destructible.
364+
344365
// --- ROW BREAK ---
345366
a|
346367
[source,c++]

sycl/include/sycl/ext/oneapi/device_global/device_global.hpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ struct HasArrowOperator<T,
4949
std::void_t<decltype(std::declval<T>().operator->())>>
5050
: std::true_type {};
5151

52+
// Checks that T is a reference to either device_global or
53+
// device_global_base. This is used by the variadic ctor to allow copy ctors to
54+
// take precedence.
55+
template <typename T>
56+
struct IsDeviceGlobalOrBaseRef : std::false_type {};
57+
5258
// Base class for device_global.
5359
template <typename T, typename PropertyListT, typename = void>
5460
class device_global_base {
@@ -65,16 +71,31 @@ class device_global_base {
6571

6672
public:
6773
#if __cpp_consteval
68-
template <typename... Args>
74+
// The SFINAE is to allow the copy constructors to take priority.
75+
template <
76+
typename... Args,
77+
std::enable_if_t<
78+
sizeof...(Args) != 1 ||
79+
(!IsDeviceGlobalOrBaseRef<std::remove_cv_t<Args>>::value && ...),
80+
int> = 0>
6981
consteval explicit device_global_base(Args &&...args) : init_val{args...} {}
7082
#else
7183
device_global_base() = default;
7284
#endif // __cpp_consteval
7385

7486
#ifndef __SYCL_DEVICE_ONLY__
87+
template <typename OtherT, typename OtherProps,
88+
typename = std::enable_if_t<std::is_convertible_v<OtherT, T>>>
89+
constexpr device_global_base(
90+
const device_global_base<OtherT, OtherProps> &DGB)
91+
: init_val{DGB.init_val} {}
7592
constexpr device_global_base(const device_global_base &DGB)
7693
: init_val{DGB.init_val} {}
7794
#else
95+
template <typename OtherT, typename OtherProps,
96+
typename = std::enable_if_t<std::is_convertible_v<OtherT, T>>>
97+
constexpr device_global_base(const device_global_base<OtherT, OtherProps> &) {
98+
}
7899
constexpr device_global_base(const device_global_base &) {}
79100
#endif // __SYCL_DEVICE_ONLY__
80101

@@ -109,12 +130,22 @@ class device_global_base<
109130

110131
public:
111132
#if __cpp_consteval
112-
template <typename... Args>
133+
// The SFINAE is to allow the copy constructors to take priority.
134+
template <
135+
typename... Args,
136+
std::enable_if_t<
137+
sizeof...(Args) != 1 ||
138+
(!IsDeviceGlobalOrBaseRef<std::remove_cv_t<Args>>::value && ...),
139+
int> = 0>
113140
consteval explicit device_global_base(Args &&...args) : val{args...} {}
114141
#else
115142
device_global_base() = default;
116143
#endif // __cpp_consteval
117144

145+
template <typename OtherT, typename OtherProps,
146+
typename = std::enable_if_t<std::is_convertible_v<OtherT, T>>>
147+
constexpr device_global_base(const device_global_base<OtherT, OtherProps> &) =
148+
delete;
118149
constexpr device_global_base(const device_global_base &) = delete;
119150

120151
template <access::decorated IsDecorated>
@@ -133,6 +164,11 @@ class device_global_base<
133164
const T>(this->get_ptr());
134165
}
135166
};
167+
168+
template <typename T, typename PropertyListT>
169+
struct IsDeviceGlobalOrBaseRef<const device_global_base<T, PropertyListT> &>
170+
: std::true_type {};
171+
136172
} // namespace detail
137173

138174
template <typename T, typename PropertyListT = empty_properties_t>
@@ -255,6 +291,12 @@ class
255291
}
256292
};
257293

294+
namespace detail {
295+
template <typename T, typename PropertyListT>
296+
struct IsDeviceGlobalOrBaseRef<device_global<T, PropertyListT> &>
297+
: std::true_type {};
298+
} // namespace detail
299+
258300
} // namespace ext::oneapi::experimental
259301
} // namespace _V1
260302
} // namespace sycl

sycl/test-e2e/DeviceGlobal/device_global_copy.cpp

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,55 @@
1111

1212
namespace oneapiext = sycl::ext::oneapi::experimental;
1313

14-
oneapiext::device_global<const int> DGInit{3};
15-
oneapiext::device_global<const int> DGCopy{DGInit};
14+
oneapiext::device_global<const int> DGInit1{3};
15+
oneapiext::device_global<const int> DGCopy1{DGInit1};
16+
17+
oneapiext::device_global<int> DGInit2{4};
18+
oneapiext::device_global<int> DGCopy2{DGInit2};
19+
20+
oneapiext::device_global<float> DGInit3{5.0f};
21+
oneapiext::device_global<int> DGCopy3{DGInit3};
22+
23+
oneapiext::device_global<const int, decltype(oneapiext::properties{oneapiext::device_image_scope})> DGInit4{6};
24+
oneapiext::device_global<const int> DGCopy4{DGInit4};
25+
26+
oneapiext::device_global<const int> DGInit5{7};
27+
oneapiext::device_global<const int, decltype(oneapiext::properties{oneapiext::host_access_read})> DGCopy5{DGInit5};
1628

1729
int main() {
1830
sycl::queue Q;
1931

2032
int ReadVals[2] = {0, 0};
2133
{
22-
sycl::buffer<int, 1> ReadValsBuff{ReadVals, 2};
34+
sycl::buffer<int, 10> ReadValsBuff{ReadVals, 2};
2335

2436
Q.submit([&](sycl::handler &CGH) {
2537
sycl::accessor ReadValsAcc{ReadValsBuff, CGH, sycl::write_only};
2638
CGH.single_task([=]() {
27-
ReadValsAcc[0] = DGInit.get();
28-
ReadValsAcc[1] = DGCopy.get();
39+
ReadValsAcc[0] = DGInit1.get();
40+
ReadValsAcc[1] = DGCopy1.get();
41+
ReadValsAcc[2] = DGInit2.get();
42+
ReadValsAcc[3] = DGCopy2.get();
43+
ReadValsAcc[4] = DGInit3.get();
44+
ReadValsAcc[5] = DGCopy3.get();
45+
ReadValsAcc[6] = DGInit4.get();
46+
ReadValsAcc[7] = DGCopy4.get();
47+
ReadValsAcc[8] = DGInit5.get();
48+
ReadValsAcc[9] = DGCopy5.get();
2949
});
3050
}).wait_and_throw();
3151
}
3252

3353
assert(ReadVals[0] == 3);
3454
assert(ReadVals[1] == 3);
55+
assert(ReadVals[2] == 4);
56+
assert(ReadVals[3] == 4);
57+
assert(ReadVals[4] == 5);
58+
assert(ReadVals[5] == 5);
59+
assert(ReadVals[6] == 6);
60+
assert(ReadVals[7] == 6);
61+
assert(ReadVals[8] == 7);
62+
assert(ReadVals[9] == 7);
3563

3664
return 0;
3765
}

sycl/test/extensions/device_global/device_global_copy_negative.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,21 @@ namespace oneapiext = sycl::ext::oneapi::experimental;
1010
using device_image_properties =
1111
decltype(oneapiext::properties{oneapiext::device_image_scope});
1212

13-
oneapiext::device_global<const int, device_image_properties> DGInit{3};
14-
oneapiext::device_global<const int, device_image_properties> DGCopy{DGInit};
13+
// expected-error@sycl/ext/oneapi/device_global/device_global.hpp:* {{call to deleted constructor}}
14+
oneapiext::device_global<const int, device_image_properties> DGInit1{3};
15+
oneapiext::device_global<const int, device_image_properties> DGCopy1{DGInit1};
1516

1617
// expected-error@sycl/ext/oneapi/device_global/device_global.hpp:* {{call to deleted constructor}}
18+
oneapiext::device_global<int, device_image_properties> DGInit2{3};
19+
oneapiext::device_global<int, device_image_properties> DGCopy2{DGInit2};
20+
21+
// expected-error@+2 {{call to deleted constructor}}
22+
oneapiext::device_global<int, device_image_properties> DGInit3{3};
23+
oneapiext::device_global<float, device_image_properties> DGCopy3{DGInit3};
24+
25+
// expected-error@+2 {{call to deleted constructor}}
26+
oneapiext::device_global<const int> DGInit4{3};
27+
oneapiext::device_global<const int, device_image_properties> DGCopy4{DGInit4};
28+
1729

1830
int main() { return 0; }

0 commit comments

Comments
 (0)