Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit b13570c

Browse files
committed
Added try_wait options
1 parent 8a3d67e commit b13570c

File tree

2 files changed

+40
-25
lines changed

2 files changed

+40
-25
lines changed

include/cuda/std/barrier

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ class barrier : public std::__barrier_base<_CompletionF, _Sco> {
4040
template<thread_scope>
4141
friend class pipeline;
4242

43-
using std::__barrier_base<_CompletionF, _Sco>::__try_wait;
44-
4543
public:
4644
barrier() = default;
4745

@@ -77,6 +75,13 @@ _LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE
7775

7876
_LIBCUDACXX_BEGIN_NAMESPACE_CUDA
7977

78+
template<class __Barrier>
79+
inline _LIBCUDACXX_INLINE_VISIBILITY
80+
bool barrier_try_wait_parity(__Barrier const* __this, bool __parity)
81+
{
82+
return __this->__try_wait_parity(__parity);
83+
}
84+
8085
template<class __Barrier>
8186
struct __barrier_poll_tester_parity {
8287
__Barrier const* __this;
@@ -91,15 +96,15 @@ struct __barrier_poll_tester_parity {
9196
inline _LIBCUDACXX_INLINE_VISIBILITY
9297
bool operator()() const
9398
{
94-
return __this->__try_wait_parity(__parity);
99+
return barrier_try_wait_parity(__this, __parity);
95100
}
96101
};
97102

98103
template<class __Barrier>
99104
inline _LIBCUDACXX_INLINE_VISIBILITY
100-
void barrier_wait_for_parity(__Barrier const* __self, bool __parity)
105+
void barrier_wait_parity(__Barrier const* __this, bool __parity)
101106
{
102-
_CUDA_VSTD::__libcpp_thread_poll_with_backoff(__barrier_poll_tester_parity<__Barrier>(__self, __parity));
107+
_CUDA_VSTD::__libcpp_thread_poll_with_backoff(__barrier_poll_tester_parity<__Barrier>(__this, __parity));
103108
}
104109

105110
template<>
@@ -114,7 +119,7 @@ public:
114119
using arrival_token = typename __barrier_base::arrival_token;
115120

116121
_LIBCUDACXX_INLINE_VISIBILITY
117-
bool __try_wait(arrival_token __phase) const {
122+
bool try_wait(arrival_token __phase) const {
118123
#if __CUDA_ARCH__ >= 800
119124
if (__isShared(&__barrier)) {
120125
int __ready = 0;
@@ -131,7 +136,7 @@ public:
131136
else
132137
#endif
133138
{
134-
return __barrier.__try_wait(std::move(__phase));
139+
return __barrier.try_wait(std::move(__phase));
135140
}
136141
}
137142

libcxx/include/barrier

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,12 @@ class __barrier_base {
209209
_LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_base<ptrdiff_t, _Sco> __expected, __arrived;
210210
_LIBCUDACXX_BARRIER_ALIGNMENTS _CompletionF __completion;
211211
_LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_base<bool, _Sco> __phase;
212+
213+
_LIBCUDACXX_INLINE_VISIBILITY
214+
bool __try_wait_phase(bool __old_phase) const
215+
{
216+
return __phase.load(memory_order_acquire) != __old_phase;
217+
}
212218
public:
213219
using arrival_token = bool;
214220

@@ -241,11 +247,15 @@ public:
241247
return __old_phase;
242248
}
243249
_LIBCUDACXX_INLINE_VISIBILITY
244-
bool __try_wait(arrival_token __old_phase) const
250+
bool try_wait(arrival_token __old) const
245251
{
246-
return __phase != __old_phase;
252+
return __try_wait_phase(__old);
253+
}
254+
_LIBCUDACXX_INLINE_VISIBILITY
255+
bool __try_wait_parity(bool __parity) const
256+
{
257+
return __try_wait_phase(__parity);
247258
}
248-
249259
_LIBCUDACXX_INLINE_VISIBILITY
250260
void wait(arrival_token&& __old_phase) const
251261
{
@@ -281,10 +291,10 @@ struct __barrier_poll_tester {
281291
, __phase(_CUDA_VSTD::move(__phase_))
282292
{}
283293

284-
inline _LIBCUDACXX_INLINE_VISIBILITY
294+
_LIBCUDACXX_INLINE_VISIBILITY
285295
bool operator()() const
286296
{
287-
return __this->__try_wait(__phase);
297+
return __this->try_wait(__phase);
288298
}
289299
};
290300

@@ -303,12 +313,18 @@ public:
303313
using arrival_token = uint64_t;
304314

305315
private:
306-
static inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR
316+
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR
307317
uint64_t __init(ptrdiff_t __count) _NOEXCEPT
308318
{
309319
return (((1u << 31) - __count) << 32)
310320
| ((1u << 31) - __count);
311321
}
322+
_LIBCUDACXX_INLINE_VISIBILITY
323+
bool __try_wait_phase(uint64_t __phase) const
324+
{
325+
uint64_t const __current = __phase_arrived_expected.load(memory_order_acquire);
326+
return ((__current & __phase_bit) != __phase);
327+
}
312328

313329
public:
314330
__barrier_base() = default;
@@ -323,19 +339,13 @@ public:
323339
__barrier_base(__barrier_base const&) = delete;
324340
__barrier_base& operator=(__barrier_base const&) = delete;
325341

326-
inline _LIBCUDACXX_INLINE_VISIBILITY
327-
bool __try_wait_phase(uint64_t __phase) const
328-
{
329-
uint64_t const __current = __phase_arrived_expected.load(memory_order_acquire);
330-
return ((__current & __phase_bit) != __phase);
331-
}
332-
inline _LIBCUDACXX_INLINE_VISIBILITY
342+
_LIBCUDACXX_INLINE_VISIBILITY
333343
bool __try_wait_parity(bool __parity) const
334344
{
335345
return __try_wait_phase(__parity ? __phase_bit : 0);
336346
}
337-
inline _LIBCUDACXX_INLINE_VISIBILITY
338-
bool __try_wait(arrival_token __old) const
347+
_LIBCUDACXX_INLINE_VISIBILITY
348+
bool try_wait(arrival_token __old) const
339349
{
340350
return __try_wait_phase(__old & __phase_bit);
341351
}
@@ -351,17 +361,17 @@ public:
351361
}
352362
return __old & __phase_bit;
353363
}
354-
inline _LIBCUDACXX_INLINE_VISIBILITY
364+
_LIBCUDACXX_INLINE_VISIBILITY
355365
void wait(arrival_token&& __phase) const
356366
{
357367
__libcpp_thread_poll_with_backoff(__barrier_poll_tester<__barrier_base<__empty_completion, _Sco>>(this, _CUDA_VSTD::move(__phase)));
358368
}
359-
inline _LIBCUDACXX_INLINE_VISIBILITY
369+
_LIBCUDACXX_INLINE_VISIBILITY
360370
void arrive_and_wait()
361371
{
362372
wait(arrive());
363373
}
364-
inline _LIBCUDACXX_INLINE_VISIBILITY
374+
_LIBCUDACXX_INLINE_VISIBILITY
365375
void arrive_and_drop()
366376
{
367377
__phase_arrived_expected.fetch_add(__expected_unit, memory_order_relaxed);

0 commit comments

Comments
 (0)