Skip to content

Commit 7506cc0

Browse files
committed
1. refactor Global_Func::find()
2. fix bug in Parallel_LRI_Equally_Filter::get_list_Ab2()
1 parent 35c750d commit 7506cc0

File tree

5 files changed

+36
-22
lines changed

5 files changed

+36
-22
lines changed

include/RI/global/Global_Func-1.h

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include "Tensor.h"
99

10+
#include <vector>
1011
#include <map>
1112
#include <set>
1213
#include <array>
@@ -17,54 +18,67 @@ namespace RI
1718

1819
namespace Global_Func
1920
{
21+
template<typename T> const T ZERO{};
22+
2023
// tensor = find(m,i,j,k);
2124
// <=>
2225
// tensor = m.at(i).at(j).at(k);
2326
// Peize Lin add 2022.05.26
2427
template<typename Tkey, typename Tdata,
2528
typename std::enable_if<std::is_arithmetic<Tdata>::value,bool>::type=0>
26-
inline Tdata find(
29+
inline const Tdata &find(
2730
const std::map<Tkey, Tdata> &m,
2831
const Tkey &key)
2932
{
30-
const auto ptr = m.find(key);
33+
const auto &ptr = m.find(key);
3134
if(ptr==m.end())
32-
return 0;
35+
return ZERO<Tdata>;
3336
else
3437
return ptr->second;
3538
}
3639
template<typename Tkey, typename Tdata>
37-
inline Tensor<Tdata> find(
40+
inline const Tensor<Tdata> &find(
3841
const std::map<Tkey, Tensor<Tdata>> &m,
3942
const Tkey &key)
4043
{
41-
const auto ptr = m.find(key);
44+
const auto &ptr = m.find(key);
4245
if(ptr==m.end())
43-
return Tensor<Tdata>{};
46+
return ZERO<Tensor<Tdata>>;
4447
else
4548
return ptr->second;
4649
}
4750
template<typename Tkey, typename Tdata, std::size_t Ndim>
48-
inline std::array<Tdata,Ndim> find(
51+
inline const std::array<Tdata,Ndim> &find(
4952
const std::map<Tkey, std::array<Tdata,Ndim>> &m,
5053
const Tkey &key)
5154
{
52-
const auto ptr = m.find(key);
55+
const auto &ptr = m.find(key);
56+
if(ptr==m.end())
57+
return ZERO<std::array<Tdata,Ndim>>;
58+
else
59+
return ptr->second;
60+
}
61+
template<typename Tkey, typename Tdata>
62+
inline const std::vector<Tdata> &find(
63+
const std::map<Tkey, std::vector<Tdata>> &m,
64+
const Tkey &key)
65+
{
66+
const auto &ptr = m.find(key);
5367
if(ptr==m.end())
54-
return std::array<Tdata,Ndim>{};
68+
return ZERO<std::vector<Tdata>>;
5569
else
5670
return ptr->second;
5771
}
58-
template<typename Tkey, typename Tvalue, typename... Tkeys>
59-
inline auto find(
60-
const std::map<Tkey, Tvalue> &m,
61-
const Tkey &key,
72+
template<typename Tkey0, typename Tkey1, typename Tvalue, typename... Tkeys>
73+
inline const auto &find(
74+
const std::map<Tkey0, std::map<Tkey1,Tvalue>> &m,
75+
const Tkey0 &key0,
6276
const Tkeys&... keys)
6377
// -> decltype(find( m.find(key)->second, keys... )) // why error for C++ compiler high version
6478
{
65-
const auto ptr = m.find(key);
79+
const auto &ptr = m.find(key0);
6680
if(ptr==m.end())
67-
return decltype(find( ptr->second, keys... )){};
81+
return ZERO<typename std::remove_reference<decltype(find( ptr->second, keys... ))>::type>;
6882
else
6983
return find( ptr->second, keys... );
7084
}

include/RI/global/Tensor-multiply.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ Tensor<T> operator* (const Tensor<T> &t1, const Tensor<T> &t2)
5353
assert(t1.shape[0] == t2.shape[0]);
5454
return Blas_Interface::gemv('T', T(1), t2, t1);
5555
}
56-
default:;
57-
throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__));
56+
default:
57+
throw std::invalid_argument("Tensor multiply:\t"+std::to_string(t1.shape.size())+" * "+std::to_string(t2.shape.size())+".\n"+std::string(__FILE__)+" line "+std::to_string(__LINE__));
5858
}
5959
}
6060
case 2:
@@ -72,11 +72,11 @@ Tensor<T> operator* (const Tensor<T> &t1, const Tensor<T> &t2)
7272
return Blas_Interface::gemm('N', 'N', T(1), t1, t2);
7373
}
7474
default:
75-
throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__));
75+
throw std::invalid_argument("Tensor multiply:\t"+std::to_string(t1.shape.size())+" * "+std::to_string(t2.shape.size())+".\n"+std::string(__FILE__)+" line "+std::to_string(__LINE__));
7676
}
7777
}
7878
default:
79-
throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__));
79+
throw std::invalid_argument("Tensor multiply:\t"+std::to_string(t1.shape.size())+" * "+std::to_string(t2.shape.size())+".\n"+std::string(__FILE__)+" line "+std::to_string(__LINE__));
8080
}
8181
}
8282

include/RI/global/Tensor_Algorithm.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ namespace Tensor_Algorithm
9090
}
9191

9292
Tensor<T> em({m.shape[0], m.shape[1]});
93-
int ie=0;
94-
for( int i=0; i!=m.shape[0]; ++i )
93+
std::size_t ie=0;
94+
for( std::size_t i=0; i<m.shape[0]; ++i )
9595
if( std::abs(eigen_values[i]) >= threshold )
9696
{
9797
Blas_Interface::axpy(m.shape[1], T(1.0)/eigen_values[i], m.ptr()+i*m.shape[1], em.ptr()+ie*em.shape[1]);

include/RI/parallel/Parallel_LRI_Equally_Filter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Parallel_LRI_Equally_Filter: public Parallel_LRI_Equally<TA,Tcell,Ndim,Tda
2020

2121
const std::vector<TAC>& get_list_Ab2 (const TA &Aa01, const TAC &Aa2, const TAC &Ab01) const override
2222
{
23-
return this->list_Ab2_filter.at(Ab01);
23+
return Global_Func::find(this->list_Ab2_filter,Ab01);
2424
}
2525

2626
void filter_Ab2 (const std::map<TA, std::map<TAC, Tensor<Tdata>>> &Ds_b);
File renamed without changes.

0 commit comments

Comments
 (0)