Skip to content

Commit 4b2f3c3

Browse files
committed
1. add namespace Communicate_Map_Combine
2. add LRI::set_tensors_map2(label_list)
1 parent c69fafc commit 4b2f3c3

19 files changed

+334
-88
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//=======================
2+
// AUTHOR : Peize Lin
3+
// DATE : 2023-08-31
4+
//=======================
5+
6+
#pragma once
7+
8+
#include "../../global/Cereal_Types.h"
9+
10+
#include <tuple>
11+
12+
namespace RI
13+
{
14+
15+
namespace Communicate_Map_Combine
16+
{
17+
template<typename TA, typename TC,
18+
template<typename T0, typename T1> class T_Judge_Map2_x,
19+
template<typename T0, typename T1> class T_Judge_Map2_y>
20+
class Judge_Map2_Combine2
21+
{
22+
using TAC = std::pair<TA,TC>;
23+
public:
24+
bool judge(const std::tuple<TA,TAC> &key) const
25+
{
26+
for(const auto &judge_i : std::get<0>(this->judge_list))
27+
if(judge_i.judge(key))
28+
return true;
29+
for(const auto &judge_i : std::get<1>(this->judge_list))
30+
if(judge_i.judge(key))
31+
return true;
32+
return false;
33+
}
34+
std::tuple<
35+
std::vector<T_Judge_Map2_x<TA,TAC>>,
36+
std::vector<T_Judge_Map2_y<TA,TC>> > judge_list;
37+
template <class Archive> void serialize( Archive & ar ){ ar(judge_list); }
38+
};
39+
}
40+
41+
}

include/RI/comm/mix/Communicate_Tensors_Map_Judge.h

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
#include "../../global/Cereal_Types.h"
1010
#include <Comm/example/Communicate_Map-2.h>
1111
#include "../example/Communicate_Map_Period.h"
12+
#include "../example/Communicate_Map_Combine.h"
1213

13-
#include <map>
1414
#include <mpi.h>
15+
#include <map>
16+
#include <tuple>
1517

1618
namespace RI
1719
{
@@ -114,6 +116,42 @@ namespace Communicate_Tensors_Map_Judge
114116
judge.s2 = s2;
115117
return Communicate_Tensors_Map::comm_map3(mpi_comm, Ds_in, judge);
116118
}
119+
120+
121+
template<typename TA, typename TC, typename Tdata>
122+
std::map<TA,std::map<std::pair<TA,TC>,Tensor<Tdata>>>
123+
comm_map2_combine_origin_period(
124+
const MPI_Comm &mpi_comm,
125+
const std::map<TA,std::map<std::pair<TA,TC>,Tensor<Tdata>>> &Ds_in,
126+
const std::tuple<
127+
std::vector<std::tuple< std::set<TA>, std::set<std::pair<TA,TC>> >>,
128+
std::vector<std::tuple< std::set<std::pair<TA,TC>>, std::set<std::pair<TA,TC>> >>
129+
> &s_list,
130+
const TC &period)
131+
{
132+
Communicate_Map_Combine::Judge_Map2_Combine2<
133+
TA, TC,
134+
Comm::Communicate_Map::Judge_Map2,
135+
Communicate_Map_Period::Judge_Map2_Period
136+
> judge_combine;
137+
138+
std::get<0>(judge_combine.judge_list).resize(std::get<0>(s_list).size());
139+
for(int j=0; j<std::get<0>(s_list).size(); ++j)
140+
{
141+
std::get<0>(judge_combine.judge_list)[j].s0 = std::get<0>(std::get<0>(s_list)[j]);
142+
std::get<0>(judge_combine.judge_list)[j].s1 = std::get<1>(std::get<0>(s_list)[j]);
143+
}
144+
145+
std::get<1>(judge_combine.judge_list).resize(std::get<1>(s_list).size());
146+
for(int j=0; j<std::get<1>(s_list).size(); ++j)
147+
{
148+
std::get<1>(judge_combine.judge_list)[j].s0 = std::get<0>(std::get<1>(s_list)[j]);
149+
std::get<1>(judge_combine.judge_list)[j].s1 = std::get<1>(std::get<1>(s_list)[j]);
150+
std::get<1>(judge_combine.judge_list)[j].period = period;
151+
}
152+
153+
return Communicate_Tensors_Map::comm_map2(mpi_comm, Ds_in, judge_combine);
154+
}
117155
}
118156

119157
}

include/RI/global/Cereal_Types.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
#include <cereal/types/vector.hpp>
1313
#include <cereal/types/valarray.hpp>
1414
#include <cereal/types/map.hpp>
15-
#include <cereal/types/set.hpp>
15+
#include <cereal/types/set.hpp>
16+
#include <cereal/types/tuple.hpp>

include/RI/parallel/Parallel_LRI.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ class Parallel_LRI
3838
virtual std::map<TA,std::map<TAC,Tensor<Tdata>>> comm_tensors_map2(
3939
const Label::ab &label,
4040
const std::map<TA,std::map<TAC,Tensor<Tdata>>> &Ds) const =0;
41+
virtual std::map<TA,std::map<TAC,Tensor<Tdata>>> comm_tensors_map2(
42+
const std::vector<Label::ab> &label,
43+
const std::map<TA,std::map<TAC,Tensor<Tdata>>> &Ds) const =0;
4144

4245
virtual const std::vector<TA >& get_list_Aa01() const =0;
4346
virtual const std::vector<TAC>& get_list_Aa2 () const =0;

include/RI/parallel/Parallel_LRI_Equally.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ class Parallel_LRI_Equally: public Parallel_LRI<TA,Tcell,Ndim,Tdata>
2727
std::map<TA,std::map<TAC,Tensor<Tdata>>> comm_tensors_map2(
2828
const Label::ab &label,
2929
const std::map<TA,std::map<TAC,Tensor<Tdata>>> &Ds) const override;
30+
std::map<TA,std::map<TAC,Tensor<Tdata>>> comm_tensors_map2(
31+
const std::vector<Label::ab> &label,
32+
const std::map<TA,std::map<TAC,Tensor<Tdata>>> &Ds) const override;
3033

3134
const std::vector<TA >& get_list_Aa01() const override { return this->list_Aa01; }
3235
const std::vector<TAC>& get_list_Aa2 () const override { return this->list_Aa2; }

include/RI/parallel/Parallel_LRI_Equally.hpp

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ auto Parallel_LRI_Equally<TA,Tcell,Ndim,Tdata>::comm_tensors_map2(
5252
case Label::ab::a0b0: case Label::ab::a0b1:
5353
case Label::ab::a1b0: case Label::ab::a1b1:
5454
return Communicate_Tensors_Map_Judge::comm_map2(this->mpi_comm, Ds, Global_Func::to_set(this->list_Aa01), Global_Func::to_set(this->list_Ab01));
55-
case Label::ab::a0b2:
56-
case Label::ab::a1b2:
55+
case Label::ab::a0b2: case Label::ab::a1b2:
5756
return Communicate_Tensors_Map_Judge::comm_map2(this->mpi_comm, Ds, Global_Func::to_set(this->list_Aa01), Global_Func::to_set(this->list_Ab2));
5857
case Label::ab::a2b0: case Label::ab::a2b1:
5958
return Communicate_Tensors_Map_Judge::comm_map2_period(this->mpi_comm, Ds, Global_Func::to_set(this->list_Aa2), Global_Func::to_set(this->list_Ab01), this->period);
@@ -64,4 +63,48 @@ auto Parallel_LRI_Equally<TA,Tcell,Ndim,Tdata>::comm_tensors_map2(
6463
}
6564
}
6665

66+
template<typename TA, typename Tcell, std::size_t Ndim, typename Tdata>
67+
auto Parallel_LRI_Equally<TA,Tcell,Ndim,Tdata>::comm_tensors_map2(
68+
const std::vector<Label::ab> &label_list,
69+
const std::map<TA,std::map<TAC,Tensor<Tdata>>> &Ds) const
70+
-> std::map<TA,std::map<TAC,Tensor<Tdata>>>
71+
{
72+
std::tuple<
73+
std::vector<std::tuple< std::set<TA>, std::set<std::pair<TA,TC>> >>,
74+
std::vector<std::tuple< std::set<std::pair<TA,TC>>, std::set<std::pair<TA,TC>> >>
75+
> s_list;
76+
77+
std::vector<bool> flags(6, false);
78+
for(const Label::ab &label : label_list)
79+
{
80+
switch(label)
81+
{
82+
case Label::ab::a:
83+
flags[0]=true; break;
84+
case Label::ab::b:
85+
flags[1]=true; break;
86+
case Label::ab::a0b0: case Label::ab::a0b1:
87+
case Label::ab::a1b0: case Label::ab::a1b1:
88+
flags[2]=true; break;
89+
case Label::ab::a0b2: case Label::ab::a1b2:
90+
flags[3]=true; break;
91+
case Label::ab::a2b0: case Label::ab::a2b1:
92+
flags[4]=true; break;
93+
case Label::ab::a2b2:
94+
flags[5]=true; break;
95+
default:
96+
throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__));
97+
}
98+
}
99+
100+
if(flags[0]) std::get<0>(s_list).push_back(std::make_tuple( Global_Func::to_set(this->list_Aa01), Global_Func::to_set(this->list_Aa2) ));
101+
if(flags[1]) std::get<1>(s_list).push_back(std::make_tuple( Global_Func::to_set(this->list_Ab01), Global_Func::to_set(this->list_Ab2) ));
102+
if(flags[2]) std::get<0>(s_list).push_back(std::make_tuple( Global_Func::to_set(this->list_Aa01), Global_Func::to_set(this->list_Ab01) ));
103+
if(flags[3]) std::get<0>(s_list).push_back(std::make_tuple( Global_Func::to_set(this->list_Aa01), Global_Func::to_set(this->list_Ab2) ));
104+
if(flags[4]) std::get<1>(s_list).push_back(std::make_tuple( Global_Func::to_set(this->list_Aa2), Global_Func::to_set(this->list_Ab01) ));
105+
if(flags[5]) std::get<1>(s_list).push_back(std::make_tuple( Global_Func::to_set(this->list_Aa2), Global_Func::to_set(this->list_Ab2) ));
106+
107+
return Communicate_Tensors_Map_Judge::comm_map2_combine_origin_period(this->mpi_comm, Ds, s_list, this->period);
108+
}
109+
67110
}

0 commit comments

Comments
 (0)