Skip to content

Commit 57d59b4

Browse files
authored
Merge #1882 Use event to handle distributed row gatherer
This PR allows to use event to handle distributed row gatherer. Event-based way can allow us to submit the next kernel without waiting for the row gatherer preparation, which leads better performance usually. Related PR: #1882
2 parents c18358c + 4fe9d4a commit 57d59b4

27 files changed

+892
-27
lines changed

benchmark/test/reference/distributed_solver.matrix.stdout

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"cg::initialize": 1.0,
2828
"advanced_apply(<typename>)": 1.0,
2929
"dense::row_gather": 1.0,
30+
"event::record_event": 1.0,
3031
"csr::advanced_spmv": 1.0,
3132
"dense::compute_squared_norm2": 1.0,
3233
"dense::compute_sqrt": 1.0,

benchmark/test/reference/distributed_solver.profile.stderr

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ DEBUG: end cg::initialize
141141
DEBUG: begin advanced_apply(<typename>)
142142
DEBUG: begin dense::row_gather
143143
DEBUG: end dense::row_gather
144+
DEBUG: begin event::record_event
145+
DEBUG: end event::record_event
144146
DEBUG: begin advanced_apply(<typename>)
145147
DEBUG: begin csr::advanced_spmv
146148
DEBUG: end csr::advanced_spmv
@@ -181,6 +183,8 @@ DEBUG: end cg::step_1
181183
DEBUG: begin apply(<typename>)
182184
DEBUG: begin dense::row_gather
183185
DEBUG: end dense::row_gather
186+
DEBUG: begin event::record_event
187+
DEBUG: end event::record_event
184188
DEBUG: begin apply(<typename>)
185189
DEBUG: begin csr::spmv
186190
DEBUG: end csr::spmv
@@ -221,6 +225,8 @@ DEBUG: end cg::step_1
221225
DEBUG: begin apply(<typename>)
222226
DEBUG: begin dense::row_gather
223227
DEBUG: end dense::row_gather
228+
DEBUG: begin event::record_event
229+
DEBUG: end event::record_event
224230
DEBUG: begin apply(<typename>)
225231
DEBUG: begin csr::spmv
226232
DEBUG: end csr::spmv
@@ -261,6 +267,8 @@ DEBUG: end cg::step_1
261267
DEBUG: begin apply(<typename>)
262268
DEBUG: begin dense::row_gather
263269
DEBUG: end dense::row_gather
270+
DEBUG: begin event::record_event
271+
DEBUG: end event::record_event
264272
DEBUG: begin apply(<typename>)
265273
DEBUG: begin csr::spmv
266274
DEBUG: end csr::spmv
@@ -301,6 +309,8 @@ DEBUG: end cg::step_1
301309
DEBUG: begin apply(<typename>)
302310
DEBUG: begin dense::row_gather
303311
DEBUG: end dense::row_gather
312+
DEBUG: begin event::record_event
313+
DEBUG: end event::record_event
304314
DEBUG: begin apply(<typename>)
305315
DEBUG: begin csr::spmv
306316
DEBUG: end csr::spmv
@@ -341,6 +351,8 @@ DEBUG: end cg::step_1
341351
DEBUG: begin apply(<typename>)
342352
DEBUG: begin dense::row_gather
343353
DEBUG: end dense::row_gather
354+
DEBUG: begin event::record_event
355+
DEBUG: end event::record_event
344356
DEBUG: begin apply(<typename>)
345357
DEBUG: begin csr::spmv
346358
DEBUG: end csr::spmv
@@ -381,6 +393,8 @@ DEBUG: end cg::step_1
381393
DEBUG: begin apply(<typename>)
382394
DEBUG: begin dense::row_gather
383395
DEBUG: end dense::row_gather
396+
DEBUG: begin event::record_event
397+
DEBUG: end event::record_event
384398
DEBUG: begin apply(<typename>)
385399
DEBUG: begin csr::spmv
386400
DEBUG: end csr::spmv
@@ -421,6 +435,8 @@ DEBUG: end cg::step_1
421435
DEBUG: begin apply(<typename>)
422436
DEBUG: begin dense::row_gather
423437
DEBUG: end dense::row_gather
438+
DEBUG: begin event::record_event
439+
DEBUG: end event::record_event
424440
DEBUG: begin apply(<typename>)
425441
DEBUG: begin csr::spmv
426442
DEBUG: end csr::spmv
@@ -470,6 +486,8 @@ DEBUG: end cg::initialize
470486
DEBUG: begin advanced_apply(<typename>)
471487
DEBUG: begin dense::row_gather
472488
DEBUG: end dense::row_gather
489+
DEBUG: begin event::record_event
490+
DEBUG: end event::record_event
473491
DEBUG: begin advanced_apply(<typename>)
474492
DEBUG: begin csr::advanced_spmv
475493
DEBUG: end csr::advanced_spmv
@@ -508,6 +526,8 @@ DEBUG: end cg::step_1
508526
DEBUG: begin apply(<typename>)
509527
DEBUG: begin dense::row_gather
510528
DEBUG: end dense::row_gather
529+
DEBUG: begin event::record_event
530+
DEBUG: end event::record_event
511531
DEBUG: begin apply(<typename>)
512532
DEBUG: begin csr::spmv
513533
DEBUG: end csr::spmv
@@ -548,6 +568,8 @@ DEBUG: end cg::step_1
548568
DEBUG: begin apply(<typename>)
549569
DEBUG: begin dense::row_gather
550570
DEBUG: end dense::row_gather
571+
DEBUG: begin event::record_event
572+
DEBUG: end event::record_event
551573
DEBUG: begin apply(<typename>)
552574
DEBUG: begin csr::spmv
553575
DEBUG: end csr::spmv
@@ -588,6 +610,8 @@ DEBUG: end cg::step_1
588610
DEBUG: begin apply(<typename>)
589611
DEBUG: begin dense::row_gather
590612
DEBUG: end dense::row_gather
613+
DEBUG: begin event::record_event
614+
DEBUG: end event::record_event
591615
DEBUG: begin apply(<typename>)
592616
DEBUG: begin csr::spmv
593617
DEBUG: end csr::spmv
@@ -628,6 +652,8 @@ DEBUG: end cg::step_1
628652
DEBUG: begin apply(<typename>)
629653
DEBUG: begin dense::row_gather
630654
DEBUG: end dense::row_gather
655+
DEBUG: begin event::record_event
656+
DEBUG: end event::record_event
631657
DEBUG: begin apply(<typename>)
632658
DEBUG: begin csr::spmv
633659
DEBUG: end csr::spmv
@@ -668,6 +694,8 @@ DEBUG: end cg::step_1
668694
DEBUG: begin apply(<typename>)
669695
DEBUG: begin dense::row_gather
670696
DEBUG: end dense::row_gather
697+
DEBUG: begin event::record_event
698+
DEBUG: end event::record_event
671699
DEBUG: begin apply(<typename>)
672700
DEBUG: begin csr::spmv
673701
DEBUG: end csr::spmv
@@ -708,6 +736,8 @@ DEBUG: end cg::step_1
708736
DEBUG: begin apply(<typename>)
709737
DEBUG: begin dense::row_gather
710738
DEBUG: end dense::row_gather
739+
DEBUG: begin event::record_event
740+
DEBUG: end event::record_event
711741
DEBUG: begin apply(<typename>)
712742
DEBUG: begin csr::spmv
713743
DEBUG: end csr::spmv
@@ -748,6 +778,8 @@ DEBUG: end cg::step_1
748778
DEBUG: begin apply(<typename>)
749779
DEBUG: begin dense::row_gather
750780
DEBUG: end dense::row_gather
781+
DEBUG: begin event::record_event
782+
DEBUG: end event::record_event
751783
DEBUG: begin apply(<typename>)
752784
DEBUG: begin csr::spmv
753785
DEBUG: end csr::spmv
@@ -789,6 +821,8 @@ DEBUG: end copy(<typename>)
789821
DEBUG: begin advanced_apply(<typename>)
790822
DEBUG: begin dense::row_gather
791823
DEBUG: end dense::row_gather
824+
DEBUG: begin event::record_event
825+
DEBUG: end event::record_event
792826
DEBUG: begin advanced_apply(<typename>)
793827
DEBUG: begin csr::advanced_spmv
794828
DEBUG: end csr::advanced_spmv

benchmark/test/reference/distributed_solver.simple.stdout

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"cg::initialize": 1.0,
3030
"advanced_apply(<typename>)": 1.0,
3131
"dense::row_gather": 1.0,
32+
"event::record_event": 1.0,
3233
"csr::advanced_spmv": 1.0,
3334
"dense::compute_squared_norm2": 1.0,
3435
"dense::compute_sqrt": 1.0,

benchmark/test/reference/distributed_solver_dcomplex.simple.stdout

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"cg::initialize": 1.0,
3030
"advanced_apply(<typename>)": 1.0,
3131
"dense::row_gather": 1.0,
32+
"event::record_event": 1.0,
3233
"csr::advanced_spmv": 1.0,
3334
"dense::compute_squared_norm2": 1.0,
3435
"dense::compute_sqrt": 1.0,

benchmark/test/reference/spmv_distributed.profile.stderr

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ DEBUG: begin repetition
126126
DEBUG: begin apply(<typename>)
127127
DEBUG: begin dense::row_gather
128128
DEBUG: end dense::row_gather
129+
DEBUG: begin event::record_event
130+
DEBUG: end event::record_event
129131
DEBUG: begin apply(<typename>)
130132
DEBUG: begin csr::spmv
131133
DEBUG: end csr::spmv

common/unified/components/fill_array_kernels.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ void fill_array(std::shared_ptr<const DefaultExecutor> exec, ValueType* array,
2727

2828
GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_FILL_ARRAY_KERNEL);
2929
template GKO_DECLARE_FILL_ARRAY_KERNEL(bool);
30+
template GKO_DECLARE_FILL_ARRAY_KERNEL(char);
3031
template GKO_DECLARE_FILL_ARRAY_KERNEL(uint16);
3132
template GKO_DECLARE_FILL_ARRAY_KERNEL(uint32);
3233
#ifndef GKO_SIZE_T_IS_UINT64_T

core/base/array.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ ValueType reduce_add(const array<ValueType>& input_arr,
9292

9393
GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_ARRAY_FILL);
9494
template GKO_DECLARE_ARRAY_FILL(bool);
95+
template GKO_DECLARE_ARRAY_FILL(char);
9596
template GKO_DECLARE_ARRAY_FILL(uint16);
9697
template GKO_DECLARE_ARRAY_FILL(uint32);
9798
#ifndef GKO_SIZE_T_IS_UINT64_T

core/base/event.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#ifndef GKO_CORE_BASE_EVENT_HPP_
6+
#define GKO_CORE_BASE_EVENT_HPP_
7+
8+
#include <memory>
9+
10+
#include <ginkgo/core/base/event.hpp>
11+
#include <ginkgo/core/base/executor.hpp>
12+
13+
14+
namespace gko {
15+
namespace detail {
16+
17+
/**
18+
* NotAsyncEvent is to provide an Event implementation on unsupported executor
19+
* like reference. It will ensure the kernels are finished when recording this
20+
* event.
21+
*/
22+
class NotAsyncEvent : public Event {
23+
public:
24+
NotAsyncEvent(std::shared_ptr<const Executor> exec) { exec->synchronize(); }
25+
26+
void synchronize() const override
27+
{
28+
// we have sync in the recording phase
29+
}
30+
};
31+
32+
33+
} // namespace detail
34+
} // namespace gko
35+
36+
37+
#endif // #ifndef GKO_CORE_BASE_EVENT_HPP_

core/base/event_kernels.hpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
2+
//
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#ifndef GKO_CORE_BASE_EVENT_KERNELS_HPP_
6+
#define GKO_CORE_BASE_EVENT_KERNELS_HPP_
7+
8+
9+
#include <memory>
10+
11+
#include <ginkgo/core/base/event.hpp>
12+
#include <ginkgo/core/base/executor.hpp>
13+
14+
#include "core/base/kernel_declaration.hpp"
15+
16+
17+
namespace gko {
18+
namespace kernels {
19+
20+
21+
#define GKO_DECLARE_EVENT_RECORD_EVENT \
22+
void record_event(std::shared_ptr<const DefaultExecutor> exec, \
23+
std::shared_ptr<const detail::Event>& event)
24+
25+
26+
#define GKO_DECLARE_ALL_AS_TEMPLATES GKO_DECLARE_EVENT_RECORD_EVENT
27+
28+
29+
GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(event, GKO_DECLARE_ALL_AS_TEMPLATES);
30+
31+
32+
#undef GKO_DECLARE_ALL_AS_TEMPLATES
33+
34+
35+
} // namespace kernels
36+
} // namespace gko
37+
38+
#endif // GKO_CORE_BASE_EVENT_KERNELS_HPP_

core/device_hooks/common_kernels.inc.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "core/base/batch_instantiation.hpp"
1111
#include "core/base/batch_multi_vector_kernels.hpp"
1212
#include "core/base/device_matrix_data_kernels.hpp"
13+
#include "core/base/event_kernels.hpp"
1314
#include "core/base/index_set_kernels.hpp"
1415
#include "core/base/mixed_precision_types.hpp"
1516
#include "core/components/absolute_array_kernels.hpp"
@@ -253,6 +254,7 @@ template GKO_DECLARE_PREFIX_SUM_NONNEGATIVE_KERNEL(size_type);
253254

254255
GKO_STUB_TEMPLATE_TYPE(GKO_DECLARE_FILL_ARRAY_KERNEL);
255256
template GKO_DECLARE_FILL_ARRAY_KERNEL(bool);
257+
template GKO_DECLARE_FILL_ARRAY_KERNEL(char);
256258
template GKO_DECLARE_FILL_ARRAY_KERNEL(uint16);
257259
template GKO_DECLARE_FILL_ARRAY_KERNEL(uint32);
258260
#ifndef GKO_SIZE_T_IS_UINT64_T
@@ -316,6 +318,15 @@ GKO_STUB_INDEX_TYPE(GKO_DECLARE_INDEX_SET_LOCAL_TO_GLOBAL_KERNEL);
316318
} // namespace idx_set
317319

318320

321+
namespace event {
322+
323+
324+
GKO_STUB(GKO_DECLARE_EVENT_RECORD_EVENT);
325+
326+
327+
}
328+
329+
319330
namespace partition {
320331

321332

0 commit comments

Comments
 (0)