Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit 0504703

Browse files
ogiroux authored and wmaxey committed
Added parity waiting
1 parent 43d3f2a commit 0504703

File tree

2 files changed

+79
-40
lines changed

2 files changed

+79
-40
lines changed

include/cuda/std/barrier

Lines changed: 48 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,31 @@ _LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE
7777

7878
_LIBCUDACXX_BEGIN_NAMESPACE_CUDA
7979

80+
// Nullary predicate object for parity-based barrier polling.
// It stores a pointer to a barrier and a parity flag; each call forwards to
// __Barrier::__try_wait_parity(__parity) on the stored barrier and returns
// that result. barrier_wait_for_parity passes an instance of this type to
// __libcpp_thread_poll_with_backoff.
template<class __Barrier>
81+
struct __barrier_poll_tester_parity {
82+
// Barrier being polled; held by pointer, not owned by this object.
__Barrier const* __this;
83+
// Parity value handed to __try_wait_parity on every call.
bool __parity;
84+
85+
_LIBCUDACXX_INLINE_VISIBILITY
86+
// Captures the barrier pointer and the parity to test for.
__barrier_poll_tester_parity(__Barrier const* __this_, bool __parity_)
87+
: __this(__this_)
88+
, __parity(__parity_)
89+
{}
90+
91+
inline _LIBCUDACXX_INLINE_VISIBILITY
92+
// Returns whatever __try_wait_parity reports for the stored parity.
bool operator()() const
93+
{
94+
return __this->__try_wait_parity(__parity);
95+
}
96+
};
97+
98+
// Waits on a barrier by parity rather than by arrival token.
// Builds a __barrier_poll_tester_parity from __self and __parity and hands it
// to __libcpp_thread_poll_with_backoff.
//   __self   - barrier to poll; must provide __try_wait_parity(bool).
//   __parity - parity value forwarded to each __try_wait_parity call.
// NOTE(review): presumably returns only once the predicate yields true; the
// poll helper is not visible here, so confirm against its definition.
template<class __Barrier>
99+
inline _LIBCUDACXX_INLINE_VISIBILITY
100+
void barrier_wait_for_parity(__Barrier const* __self, bool __parity)
101+
{
102+
_CUDA_VSTD::__libcpp_thread_poll_with_backoff(__barrier_poll_tester_parity<__Barrier>(__self, __parity));
103+
}
104+
80105
template<>
81106
class barrier<thread_scope_block, std::__empty_completion> : public __block_scope_barrier_base {
82107
using __barrier_base = std::__barrier_base<std::__empty_completion, (int)thread_scope_block>;
@@ -88,24 +113,6 @@ class barrier<thread_scope_block, std::__empty_completion> : public __block_scop
88113
public:
89114
using arrival_token = typename __barrier_base::arrival_token;
90115

91-
private:
92-
struct __poll_tester {
93-
barrier const* __this;
94-
arrival_token __phase;
95-
96-
_LIBCUDACXX_INLINE_VISIBILITY
97-
__poll_tester(barrier const* __this_, arrival_token&& __phase_)
98-
: __this(__this_)
99-
, __phase(_CUDA_VSTD::move(__phase_))
100-
{}
101-
102-
inline _LIBCUDACXX_INLINE_VISIBILITY
103-
bool operator()() const
104-
{
105-
return __this->__try_wait(__phase);
106-
}
107-
};
108-
109116
_LIBCUDACXX_INLINE_VISIBILITY
110117
bool __try_wait(arrival_token __phase) const {
111118
#if __CUDA_ARCH__ >= 800
@@ -131,7 +138,28 @@ private:
131138
template<thread_scope>
132139
friend class pipeline;
133140

134-
public:
141+
// Single, non-blocking parity test on this barrier.
// When compiled for __CUDA_ARCH__ >= 800 and __isShared(&__barrier) is true,
// the test is done with the mbarrier.test_wait.parity PTX instruction;
// otherwise it is delegated to __barrier.__try_wait_parity(__parity).
// Returns the outcome of that test as a bool.
// NOTE(review): per the PTX ISA the instruction presumably reports whether the
// phase with the given parity has completed; confirm against the PTX docs.
_LIBCUDACXX_INLINE_VISIBILITY
142+
bool __try_wait_parity(bool __parity) const {
143+
#if __CUDA_ARCH__ >= 800
144+
if (__isShared(&__barrier)) {
145+
// Receives 1 or 0 from the selp instruction below.
int __ready = 0;
146+
// Operands: %0 = __ready (output), %1 = address of __barrier converted with
// __cvta_generic_to_shared and narrowed to 32 bits, %2 = __parity as a
// 32-bit value. The predicate p holds the test result and selp turns it
// into 1/0 in %0.
asm volatile ("{\n\t"
147+
".reg .pred p;\n\t"
148+
"mbarrier.test_wait.parity.shared.b64 p, [%1], %2;\n\t"
149+
"selp.b32 %0, 1, 0, p;\n\t"
150+
"}"
151+
: "=r"(__ready)
152+
: "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier))), "r"(static_cast<std::uint32_t>(__parity))
153+
: "memory");
154+
return bool(__ready);
155+
}
156+
// The `else` sits inside the #if so that, when this branch is compiled out,
// the braces below become the unconditional body of the function.
else
157+
#endif
158+
{
159+
return __barrier.__try_wait_parity(__parity);
160+
}
161+
}
162+
135163
barrier() = default;
136164

137165
barrier(const barrier &) = delete;
@@ -216,7 +244,7 @@ public:
216244
_LIBCUDACXX_INLINE_VISIBILITY
217245
void wait(arrival_token && __phase) const
218246
{
219-
_CUDA_VSTD::__libcpp_thread_poll_with_backoff(__poll_tester(this, _CUDA_VSTD::move(__phase)));
247+
_CUDA_VSTD::__libcpp_thread_poll_with_backoff(std::__barrier_poll_tester<barrier>(this, _CUDA_VSTD::move(__phase)));
220248
}
221249

222250
inline _LIBCUDACXX_INLINE_VISIBILITY

libcxx/include/barrier

Lines changed: 31 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -270,6 +270,24 @@ public:
270270
}
271271
};
272272

273+
// Nullary predicate object for token-based barrier polling, shared by any
// barrier type that exposes `arrival_token` and `__try_wait(arrival_token)`.
// It stores a barrier pointer plus an arrival token; each call forwards to
// __this->__try_wait(__phase) and returns that result.
template<class __Barrier>
274+
struct __barrier_poll_tester {
275+
// Barrier being polled; held by pointer, not owned by this object.
__Barrier const* __this;
276+
// Arrival token passed to __try_wait on every call.
typename __Barrier::arrival_token __phase;
277+
278+
_LIBCUDACXX_INLINE_VISIBILITY
279+
// Captures the barrier pointer and moves the caller's token into __phase.
__barrier_poll_tester(__Barrier const* __this_, typename __Barrier::arrival_token&& __phase_)
280+
: __this(__this_)
281+
, __phase(_CUDA_VSTD::move(__phase_))
282+
{}
283+
284+
inline _LIBCUDACXX_INLINE_VISIBILITY
285+
// Returns whatever __try_wait reports for the stored token.
bool operator()() const
286+
{
287+
return __this->__try_wait(__phase);
288+
}
289+
};
290+
273291
template<int _Sco>
274292
class __barrier_base<__empty_completion, _Sco> {
275293

@@ -285,23 +303,6 @@ public:
285303
using arrival_token = uint64_t;
286304

287305
private:
288-
struct __poll_tester {
289-
__barrier_base const* __this;
290-
arrival_token __phase;
291-
292-
_LIBCUDACXX_INLINE_VISIBILITY
293-
__poll_tester(__barrier_base const* __this_, arrival_token&& __phase_)
294-
: __this(__this_)
295-
, __phase(_CUDA_VSTD::move(__phase_))
296-
{}
297-
298-
inline _LIBCUDACXX_INLINE_VISIBILITY
299-
bool operator()() const
300-
{
301-
return __this->__try_wait(__phase);
302-
}
303-
};
304-
305306
static inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR
306307
uint64_t __init(ptrdiff_t __count) _NOEXCEPT
307308
{
@@ -322,12 +323,22 @@ public:
322323
__barrier_base(__barrier_base const&) = delete;
323324
__barrier_base& operator=(__barrier_base const&) = delete;
324325

325-
_LIBCUDACXX_INLINE_VISIBILITY
326-
bool __try_wait(arrival_token __phase) const
326+
// Common phase test used by __try_wait_parity and __try_wait.
//   __phase - expected to already be masked: either 0 or __phase_bit.
// Loads __phase_arrived_expected with acquire ordering and returns true when
// its __phase_bit field differs from __phase.
inline _LIBCUDACXX_INLINE_VISIBILITY
327+
bool __try_wait_phase(uint64_t __phase) const
327328
{
328329
uint64_t const __current = __phase_arrived_expected.load(memory_order_acquire);
329330
// Only the phase bit takes part in the comparison; other bits are masked off.
return ((__current & __phase_bit) != __phase);
330331
}
332+
// Parity-based test: maps `true` to __phase_bit and `false` to 0, then
// delegates to __try_wait_phase with that value and returns its result.
inline _LIBCUDACXX_INLINE_VISIBILITY
333+
bool __try_wait_parity(bool __parity) const
334+
{
335+
return __try_wait_phase(__parity ? __phase_bit : 0);
336+
}
337+
// Token-based test: keeps only the __phase_bit field of the arrival token
// __old, then delegates to __try_wait_phase with it and returns its result.
inline _LIBCUDACXX_INLINE_VISIBILITY
338+
bool __try_wait(arrival_token __old) const
339+
{
340+
return __try_wait_phase(__old & __phase_bit);
341+
}
331342

332343
_LIBCUDACXX_NODISCARD_ATTRIBUTE inline _LIBCUDACXX_INLINE_VISIBILITY
333344
arrival_token arrive(ptrdiff_t __update = 1)
@@ -343,7 +354,7 @@ public:
343354
inline _LIBCUDACXX_INLINE_VISIBILITY
344355
void wait(arrival_token&& __phase) const
345356
{
346-
__libcpp_thread_poll_with_backoff(__poll_tester(this, _CUDA_VSTD::move(__phase)));
357+
__libcpp_thread_poll_with_backoff(__barrier_poll_tester<__barrier_base<__empty_completion, _Sco>>(this, _CUDA_VSTD::move(__phase)));
347358
}
348359
inline _LIBCUDACXX_INLINE_VISIBILITY
349360
void arrive_and_wait()

0 commit comments

Comments
 (0)