Skip to content

Commit f8dcacf

Browse files
committed
rotation funcs
1 parent 91dc587 commit f8dcacf

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

include/RI/global/Tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class Tensor
4545
inline T& operator() (const std::size_t i0, const std::size_t i1, const std::size_t i2, const std::size_t i3) const;
4646

4747
Tensor transpose() const;
48+
Tensor dagger() const;
4849

4950
// ||d||_p = (|d_1|^p+|d_2|^p+...)^{1/p}
5051
// if(p==std::numeric_limits<double>::max()) ||d||_max = max_i |d_i|

include/RI/global/Tensor.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,17 @@ Tensor<T> Tensor<T>::transpose() const
178178
return t;
179179
}
180180

181+
template<typename T>
182+
Tensor<T> Tensor<T>::dagger() const
183+
{
184+
assert(this->shape.size() == 2);
185+
Tensor<T> t({ this->shape[1], this->shape[0] });
186+
for (std::size_t i0 = 0; i0 < this->shape[0]; ++i0)
187+
for (std::size_t i1 = 0; i1 < this->shape[1]; ++i1)
188+
t(i1, i0) = std::conj((*this)(i0, i1));
189+
return t;
190+
}
191+
181192
template<typename T>
182193
Global_Func::To_Real_t<T> Tensor<T>::norm(const double p) const
183194
{
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include <RI/global/Tensor.h>
2+
namespace RI
3+
{
4+
namespace Sym
5+
{
6+
template <typename T>
7+
inline void T1_HR(T* TA, const T* A, const Tensor<T>& T1, const int& n2)
8+
{
9+
// C' = T1^\dagger * C
10+
const int& n1 = T1.shape[0];
11+
Blas_Interface::gemm('C', 'N', n1, n2, n1,
12+
T(1), T1.ptr(), n1, A, n2, T(0), TA, n2);
13+
// zgemm_(&notrans, &dagger, &n12, &nabf, &nabf,
14+
// &alpha, A, &n12, T1.ptr(), &nabf, &beta, TA, &n12);
15+
}
16+
17+
template <typename T>
18+
inline void T1_HR_T2(T* TAT, const T* A, const Tensor<T>& T1, const Tensor<T>& T2)
19+
{
20+
// H' = T1^\dagger * H * T2
21+
const int& n2 = T2.shape[0], & n1 = T1.shape[0];
22+
const RI::Shape_Vector& shape = { static_cast<size_t>(n1),static_cast<size_t>(n2) };
23+
RI::Tensor<T> AT2(shape);
24+
Blas_Interface::gemm('N', 'N', n1, n2, n2,
25+
T(1), A, n2, T2.ptr(), n2, T(0), AT2.ptr(), n2);
26+
Blas_Interface::gemm('C', 'N', n1, n2, n1,
27+
T(1), T1.ptr(), n1, AT2.ptr(), n2, T(0), TAT, n2);
28+
// col-major version
29+
// zgemm_(&notrans, &notrans, &n2, &n1, &n2,
30+
// &alpha, T2.ptr(), &n2, A, &n2, &beta, AT2.ptr(), &n2);
31+
// zgemm_(&notrans, &dagger, &n2, &n1, &n1,
32+
// &alpha, AT2.ptr(), &n2, T1.ptr(), &n1, &beta, TAT, &n2);
33+
}
34+
35+
template<typename T>
36+
inline void T1_DR_T2(T* TAT, const T* A, const Tensor<T>& T1, const Tensor<T>& T2)
37+
{
38+
// D' = T1^T * D * T2^* = T1^T * [T2^\dagger * D^T]^T
39+
const int& n2 = T2.shape[0], & n1 = T1.shape[0];
40+
const RI::Shape_Vector& shape = { static_cast<size_t>(n1),static_cast<size_t>(n2) };
41+
RI::Tensor<T> AT2(shape);
42+
BlasConnector::gemm('C', 'T', n2, n1, n2,
43+
T(1), T2.ptr(), n2, A, n2, T(0), AT2.ptr(), n1);
44+
BlasConnector::gemm('T', 'T', n1, n2, n1,
45+
T(1), T1.ptr(), n1, AT2.ptr(), n1, T(0), TAT, n2);
46+
// col-major version
47+
// zgemm_(&transpose, &dagger, &nw1, &nw2, &nw2,
48+
// &alpha, A, &nw2, T2.ptr(), &nw2, &beta, AT2.ptr(), &nw1);
49+
// zgemm_(&transpose, &transpose, &nw2, &nw1, &nw1,
50+
// &alpha, AT2.ptr(), &nw1, T1.ptr(), &nw1, &beta, TAT, &nw2);
51+
}
52+
53+
}
54+
55+
}

0 commit comments

Comments
 (0)