99namespace cosma {
1010#ifdef COSMA_HAVE_GPU
1111template <typename Scalar>
12- gpu::mm_handle<Scalar>* cosma_context<Scalar>::get_gpu_context() {
12+ gpu::mm_handle<Scalar> * cosma_context<Scalar>::get_gpu_context() {
1313 return gpu_ctx_.get ();
1414}
1515#endif
@@ -21,26 +21,29 @@ cosma_context<Scalar>::cosma_context() {
2121 overlap_comm_and_comp = get_overlap_comm_and_comp ();
2222 pin_host_buffers = get_memory_pinning ();
2323#ifdef COSMA_HAVE_GPU
24- gpu_ctx_ = gpu::make_context<Scalar>(gpu_streams (),
25- gpu_max_tile_m (),
26- gpu_max_tile_n (),
27- gpu_max_tile_k ());
24+ gpu_ctx_ = gpu::make_context<Scalar>(
25+ gpu_streams (), gpu_max_tile_m (), gpu_max_tile_n (), gpu_max_tile_k ());
2826#endif
2927}
3028
3129template <typename Scalar>
32- cosma_context<Scalar>::cosma_context(size_t cpu_mem_limit, int streams, int tile_m, int tile_n, int tile_k) {
33- cpu_memory_limit = (long long ) cpu_mem_limit;
30+ cosma_context<Scalar>::cosma_context(size_t cpu_mem_limit,
31+ int streams,
32+ int tile_m,
33+ int tile_n,
34+ int tile_k) {
35+ cpu_memory_limit = (long long )cpu_mem_limit;
3436 adapt_to_scalapack_strategy = get_adapt_strategy ();
3537 overlap_comm_and_comp = get_overlap_comm_and_comp ();
3638 pin_host_buffers = get_memory_pinning ();
3739 memory_pool_.amortization = get_memory_pool_amortization ();
3840 // do not reserve nor resize the memory pool
3941 // let this just serve as the upper bound when creating a strategy
40- // because otherwise, it might reserve/resize to much more than the problem requires
41- // memory_pool_.resize(cpu_mem_limit);
42+ // because otherwise, it might reserve/resize to much more than the problem
43+ // requires. (Disabled call: memory_pool_.resize(cpu_mem_limit);)
4244#ifdef COSMA_HAVE_GPU
4345 gpu_ctx_ = gpu::make_context<Scalar>(streams, tile_m, tile_n, tile_k);
46+ gpu_ctx_.use_unified_memory_ = cosma::get_unified_memory ();
4447#else
4548 std::cout << " Ignoring parameters in make_context. These parameters only "
4649 " used in the CPU version."
@@ -59,24 +62,30 @@ cosma_context<Scalar>::~cosma_context() {
5962}
6063
6164template <typename Scalar>
62- memory_pool<Scalar>& cosma_context<Scalar>::get_memory_pool() {
65+ memory_pool<Scalar> & cosma_context<Scalar>::get_memory_pool() {
6366 return memory_pool_;
6467}
6568
69+ template <typename Scalar>
70+ bool cosma_context<Scalar>::unified_memory() {
71+ return unified_memory_;
72+ }
73+
6674template <typename Scalar>
6775long long cosma_context<Scalar>::get_cpu_memory_limit() {
6876 return cpu_memory_limit;
6977}
7078
7179template <typename Scalar>
72- cosma::communicator* cosma_context<Scalar>::get_cosma_comm() {
80+ cosma::communicator * cosma_context<Scalar>::get_cosma_comm() {
7381 return prev_cosma_comm.get ();
7482}
7583
7684template <typename Scalar>
7785void cosma_context<Scalar>::register_state(MPI_Comm comm,
7886 const Strategy strategy) {
79- if (comm == MPI_COMM_NULL) return ;
87+ if (comm == MPI_COMM_NULL)
88+ return ;
8089
8190 int same_comm = 0 ;
8291
@@ -90,38 +99,31 @@ void cosma_context<Scalar>::register_state(MPI_Comm comm,
9099 MPI_Comm prev_comm = prev_cosma_comm->full_comm ();
91100 int comm_compare;
92101 MPI_Comm_compare (prev_comm, comm, &comm_compare);
93- same_comm = comm_compare == MPI_CONGRUENT ||
94- comm_compare == MPI_IDENT;
102+ same_comm = comm_compare == MPI_CONGRUENT || comm_compare == MPI_IDENT;
95103
96- bool same_strategy = strategy == prev_strategy;
104+ bool same_strategy = strategy == prev_strategy;
97105
98106 // if same_comm and same strategy -> reuse the communicators
99107 if (!same_comm || !same_strategy) {
100108 prev_strategy = strategy;
101109
102110 PE (preprocessing_communicators);
103- prev_cosma_comm = std::make_unique<cosma::communicator>(strategy, comm);
111+ prev_cosma_comm =
112+ std::make_unique<cosma::communicator>(strategy, comm);
104113 PL ();
105114
106- memory_pool_.unpin_all ();
107- memory_pool_.already_pinned = false ;
108- memory_pool_.resized = false ;
115+ memory_pool_.unpin_all ();
116+ memory_pool_.already_pinned = false ;
117+ memory_pool_.resized = false ;
109118 }
110119 }
111120
112121 // if this rank is not taking part in multiply, return
113122 // if (prev_cosma_comm->is_idle()) return;
114123
115124#ifdef COSMA_HAVE_GPU
116- if (
117- !prev_cosma_comm->is_idle ()
118- &&
119- !memory_pool_.resized
120- &&
121- same_comm
122- &&
123- strategy == prev_strategy
124- ) {
125+ if (!prev_cosma_comm->is_idle () && !memory_pool_.resized && same_comm &&
126+ strategy == prev_strategy) {
125127 memory_pool_.already_pinned = true ;
126128 }
127129#endif
@@ -139,8 +141,13 @@ context<Scalar> make_context() {
139141}
140142
141143template <typename Scalar>
142- context<Scalar> make_context (size_t cpu_mem_limit, int streams, int tile_m, int tile_n, int tile_k) {
143- return std::make_unique<cosma_context<Scalar>>(cpu_mem_limit, streams, tile_m, tile_n, tile_k);
144+ context<Scalar> make_context (size_t cpu_mem_limit,
145+ int streams,
146+ int tile_m,
147+ int tile_n,
148+ int tile_k) {
149+ return std::make_unique<cosma_context<Scalar>>(
150+ cpu_mem_limit, streams, tile_m, tile_n, tile_k);
144151}
145152
146153// Meyer's singleton, thread-safe in C++11, but not in C++03.
@@ -171,29 +178,29 @@ template context<zfloat> make_context();
171178template context<zdouble> make_context ();
172179
173180template context<float > make_context (size_t cpu_mem_limit,
174- int streams,
175- int tile_m,
176- int tile_n,
177- int tile_k);
181+ int streams,
182+ int tile_m,
183+ int tile_n,
184+ int tile_k);
178185template context<double > make_context (size_t cpu_mem_limit,
179- int streams,
180- int tile_m,
181- int tile_n,
182- int tile_k);
186+ int streams,
187+ int tile_m,
188+ int tile_n,
189+ int tile_k);
183190template context<zfloat> make_context (size_t cpu_mem_limit,
184- int streams,
185- int tile_m,
186- int tile_n,
187- int tile_k);
191+ int streams,
192+ int tile_m,
193+ int tile_n,
194+ int tile_k);
188195template context<zdouble> make_context (size_t cpu_mem_limit,
189- int streams,
190- int tile_m,
191- int tile_n,
192- int tile_k);
196+ int streams,
197+ int tile_m,
198+ int tile_n,
199+ int tile_k);
193200
194201// template instantiation for get_context_instance
195202template global_context<float > get_context_instance ();
196203template global_context<double > get_context_instance ();
197204template global_context<zfloat> get_context_instance ();
198205template global_context<zdouble> get_context_instance ();
199- }
206+ } // namespace cosma
0 commit comments