@@ -77,6 +77,31 @@ _LIBCUDACXX_END_NAMESPACE_CUDA_DEVICE
7777
7878_LIBCUDACXX_BEGIN_NAMESPACE_CUDA
7979
80+ template<class __Barrier>
81+ struct __barrier_poll_tester_parity {
82+ __Barrier const* __this;
83+ bool __parity;
84+
85+ _LIBCUDACXX_INLINE_VISIBILITY
86+ __barrier_poll_tester_parity(__Barrier const* __this_, bool __parity_)
87+ : __this(__this_)
88+ , __parity(__parity_)
89+ {}
90+
91+ inline _LIBCUDACXX_INLINE_VISIBILITY
92+ bool operator()() const
93+ {
94+ return __this->__try_wait_parity(__parity);
95+ }
96+ };
97+
98+ template<class __Barrier>
99+ inline _LIBCUDACXX_INLINE_VISIBILITY
100+ void barrier_wait_for_parity(__Barrier const* __self, bool __parity)
101+ {
102+ _CUDA_VSTD::__libcpp_thread_poll_with_backoff(__barrier_poll_tester_parity<__Barrier>(__self, __parity));
103+ }
104+
80105template<>
81106class barrier<thread_scope_block, std::__empty_completion> : public __block_scope_barrier_base {
82107 using __barrier_base = std::__barrier_base<std::__empty_completion, (int)thread_scope_block>;
@@ -88,24 +113,6 @@ class barrier<thread_scope_block, std::__empty_completion> : public __block_scop
88113public:
89114 using arrival_token = typename __barrier_base::arrival_token;
90115
91- private:
92- struct __poll_tester {
93- barrier const* __this;
94- arrival_token __phase;
95-
96- _LIBCUDACXX_INLINE_VISIBILITY
97- __poll_tester(barrier const* __this_, arrival_token&& __phase_)
98- : __this(__this_)
99- , __phase(_CUDA_VSTD::move(__phase_))
100- {}
101-
102- inline _LIBCUDACXX_INLINE_VISIBILITY
103- bool operator()() const
104- {
105- return __this->__try_wait(__phase);
106- }
107- };
108-
109116 _LIBCUDACXX_INLINE_VISIBILITY
110117 bool __try_wait(arrival_token __phase) const {
111118#if __CUDA_ARCH__ >= 800
@@ -131,7 +138,28 @@ private:
131138 template<thread_scope>
132139 friend class pipeline;
133140
134- public:
141+ _LIBCUDACXX_INLINE_VISIBILITY
142+ bool __try_wait_parity(bool __parity) const {
143+ #if __CUDA_ARCH__ >= 800
144+ if (__isShared(&__barrier)) {
145+ int __ready = 0;
146+ asm volatile ("{\n\t"
147+ ".reg .pred p;\n\t"
148+ "mbarrier.test_wait.parity.shared.b64 p, [%1], %2;\n\t"
149+ "selp.b32 %0, 1, 0, p;\n\t"
150+ "}"
151+ : "=r"(__ready)
152+ : "r"(static_cast<std::uint32_t>(__cvta_generic_to_shared(&__barrier))), "r"(static_cast<std::uint32_t>(__parity))
153+ : "memory");
154+ return bool(__ready);
155+ }
156+ else
157+ #endif
158+ {
159+ return __barrier.__try_wait_parity(__parity);
160+ }
161+ }
162+
135163 barrier() = default;
136164
137165 barrier(const barrier &) = delete;
@@ -216,7 +244,7 @@ public:
216244 _LIBCUDACXX_INLINE_VISIBILITY
217245 void wait(arrival_token && __phase) const
218246 {
219- _CUDA_VSTD::__libcpp_thread_poll_with_backoff(__poll_tester (this, _CUDA_VSTD::move(__phase)));
247+ _CUDA_VSTD::__libcpp_thread_poll_with_backoff(std::__barrier_poll_tester<barrier> (this, _CUDA_VSTD::move(__phase)));
220248 }
221249
222250 inline _LIBCUDACXX_INLINE_VISIBILITY
0 commit comments