Skip to content

Commit d9f652e

Browse files
committed
Merge branch 'loop3-symmetry-new' into filter_exx
2 parents 9daa123 + 9228633 commit d9f652e

File tree

4 files changed

+135
-0
lines changed

4 files changed

+135
-0
lines changed

include/RI/global/Tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class Tensor
4545
inline T& operator() (const std::size_t i0, const std::size_t i1, const std::size_t i2, const std::size_t i3) const;
4646

4747
Tensor transpose() const;
48+
Tensor dagger() const;
4849

4950
// ||d||_p = (|d_1|^p+|d_2|^p+...)^{1/p}
5051
// if(p==std::numeric_limits<double>::max()) ||d||_max = max_i |d_i|

include/RI/global/Tensor.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,17 @@ Tensor<T> Tensor<T>::transpose() const
178178
return t;
179179
}
180180

181+
template<typename T>
182+
Tensor<T> Tensor<T>::dagger() const
183+
{
184+
assert(this->shape.size() == 2);
185+
Tensor<T> t({ this->shape[1], this->shape[0] });
186+
for (std::size_t i0 = 0; i0 < this->shape[0]; ++i0)
187+
for (std::size_t i1 = 0; i1 < this->shape[1]; ++i1)
188+
t(i1, i0) = std::conj((*this)(i0, i1));
189+
return t;
190+
}
191+
181192
template<typename T>
182193
Global_Func::To_Real_t<T> Tensor<T>::norm(const double p) const
183194
{
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#include <array>
2+
#include <map>
3+
#include <set>
4+
#include <tuple>
5+
#define NO_SEC_RETURN_TRUE if(this->irreducible_sector_.empty()) return true;
6+
#include "../global/Array_Operator.h"
7+
namespace RI
8+
{
9+
using namespace Array_Operator;
10+
template<typename TA, typename Tcell, std::size_t Ndim, typename Tdata>
11+
class Symmetry_Filter
12+
{
13+
using TC = std::array<Tcell, Ndim>;
14+
using TAC = std::pair<TA, TC>;
15+
16+
using TIJ = std::pair<TA, TA>;
17+
using TIJR = std::pair<TIJ, TC>;
18+
using Tsec = std::map<TIJ, std::set<TC>>;
19+
public:
20+
Symmetry_Filter(const TC& period_in, const Tsec& irsec)
21+
:period(period_in), irreducible_sector_(irsec) {}
22+
bool in_irreducible_sector(const TA& Aa, const TAC& Ab) const
23+
{
24+
NO_SEC_RETURN_TRUE;
25+
const TIJ& ap = { Aa, Ab.first };
26+
if (irreducible_sector_.find(ap) != irreducible_sector_.end())
27+
if (irreducible_sector_.at(ap).find(Ab.second % this->period) != irreducible_sector_.at(ap).end())
28+
return true;
29+
return false;
30+
}
31+
bool in_irreducible_sector(const TAC& Aa, const TAC& Ab) const
32+
{
33+
NO_SEC_RETURN_TRUE;
34+
const TC dR = (Ab.second - Aa.second) % this->period;
35+
const std::pair<TA, TA> ap = { Aa.first, Ab.first };
36+
if (irreducible_sector_.find(ap) != irreducible_sector_.end())
37+
if (irreducible_sector_.at(ap).find(dR) != irreducible_sector_.at(ap).end())
38+
return true;
39+
return false;
40+
}
41+
bool is_I_in_irreducible_sector(const TA& Aa) const
42+
{
43+
NO_SEC_RETURN_TRUE;
44+
for (const auto& apRs : irreducible_sector_)
45+
if (apRs.first.first == Aa)return true;
46+
return false;
47+
}
48+
bool is_J_in_irreducible_sector(const TA& Ab) const
49+
{
50+
NO_SEC_RETURN_TRUE;
51+
for (const auto& apRs : irreducible_sector_)
52+
if (apRs.first.second == Ab)return true;
53+
return false;
54+
}
55+
TIJR get_IJR(const TA& I, const TAC& J) const
56+
{
57+
return { {I,J.first}, J.second % this->period };
58+
}
59+
TIJR get_IJR(const TAC& I, const TAC& J) const
60+
{
61+
return { {I.first,J.first}, (J.second - I.second) % this->period };
62+
}
63+
private:
64+
const Tsec& irreducible_sector_;
65+
const TC& period;
66+
};
67+
68+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include <RI/global/Tensor.h>
2+
namespace RI
3+
{
4+
namespace Sym
5+
{
6+
template <typename T>
7+
inline void T1_HR(T* TA, const T* A, const Tensor<T>& T1, const int& n2)
8+
{
9+
// C' = T1^\dagger * C
10+
const int& n1 = T1.shape[0];
11+
Blas_Interface::gemm('C', 'N', n1, n2, n1,
12+
T(1), T1.ptr(), n1, A, n2, T(0), TA, n2);
13+
// zgemm_(&notrans, &dagger, &n12, &nabf, &nabf,
14+
// &alpha, A, &n12, T1.ptr(), &nabf, &beta, TA, &n12);
15+
}
16+
17+
template <typename T>
18+
inline void T1_HR_T2(T* TAT, const T* A, const Tensor<T>& T1, const Tensor<T>& T2)
19+
{
20+
// H' = T1^\dagger * H * T2
21+
const int& n2 = T2.shape[0], & n1 = T1.shape[0];
22+
const RI::Shape_Vector& shape = { static_cast<size_t>(n1),static_cast<size_t>(n2) };
23+
RI::Tensor<T> AT2(shape);
24+
Blas_Interface::gemm('N', 'N', n1, n2, n2,
25+
T(1), A, n2, T2.ptr(), n2, T(0), AT2.ptr(), n2);
26+
Blas_Interface::gemm('C', 'N', n1, n2, n1,
27+
T(1), T1.ptr(), n1, AT2.ptr(), n2, T(0), TAT, n2);
28+
// col-major version
29+
// zgemm_(&notrans, &notrans, &n2, &n1, &n2,
30+
// &alpha, T2.ptr(), &n2, A, &n2, &beta, AT2.ptr(), &n2);
31+
// zgemm_(&notrans, &dagger, &n2, &n1, &n1,
32+
// &alpha, AT2.ptr(), &n2, T1.ptr(), &n1, &beta, TAT, &n2);
33+
}
34+
35+
template<typename T>
36+
inline void T1_DR_T2(T* TAT, const T* A, const Tensor<T>& T1, const Tensor<T>& T2)
37+
{
38+
// D' = T1^T * D * T2^* = T1^T * [T2^\dagger * D^T]^T
39+
const int& n2 = T2.shape[0], & n1 = T1.shape[0];
40+
const RI::Shape_Vector& shape = { static_cast<size_t>(n1),static_cast<size_t>(n2) };
41+
RI::Tensor<T> AT2(shape);
42+
BlasConnector::gemm('C', 'T', n2, n1, n2,
43+
T(1), T2.ptr(), n2, A, n2, T(0), AT2.ptr(), n1);
44+
BlasConnector::gemm('T', 'T', n1, n2, n1,
45+
T(1), T1.ptr(), n1, AT2.ptr(), n1, T(0), TAT, n2);
46+
// col-major version
47+
// zgemm_(&transpose, &dagger, &nw1, &nw2, &nw2,
48+
// &alpha, A, &nw2, T2.ptr(), &nw2, &beta, AT2.ptr(), &nw1);
49+
// zgemm_(&transpose, &transpose, &nw2, &nw1, &nw1,
50+
// &alpha, AT2.ptr(), &nw1, T1.ptr(), &nw1, &beta, TAT, &nw2);
51+
}
52+
53+
}
54+
55+
}

0 commit comments

Comments
 (0)