#include <cstddef>

#include <RI/global/Tensor.h>
2+ namespace RI
3+ {
4+ namespace Sym
5+ {
6+ template <typename T>
7+ inline void T1_HR (T* TA, const T* A, const Tensor<T>& T1, const int & n2)
8+ {
9+ // C' = T1^\dagger * C
10+ const int & n1 = T1.shape [0 ];
11+ Blas_Interface::gemm (' C' , ' N' , n1, n2, n1,
12+ T (1 ), T1.ptr (), n1, A, n2, T (0 ), TA, n2);
13+ // zgemm_(¬rans, &dagger, &n12, &nabf, &nabf,
14+ // &alpha, A, &n12, T1.ptr(), &nabf, &beta, TA, &n12);
15+ }
16+
17+ template <typename T>
18+ inline void T1_HR_T2 (T* TAT, const T* A, const Tensor<T>& T1, const Tensor<T>& T2)
19+ {
20+ // H' = T1^\dagger * H * T2
21+ const int & n2 = T2.shape [0 ], & n1 = T1.shape [0 ];
22+ const RI::Shape_Vector& shape = { static_cast <size_t >(n1),static_cast <size_t >(n2) };
23+ RI::Tensor<T> AT2 (shape);
24+ Blas_Interface::gemm (' N' , ' N' , n1, n2, n2,
25+ T (1 ), A, n2, T2.ptr (), n2, T (0 ), AT2.ptr (), n2);
26+ Blas_Interface::gemm (' C' , ' N' , n1, n2, n1,
27+ T (1 ), T1.ptr (), n1, AT2.ptr (), n2, T (0 ), TAT, n2);
28+ // col-major version
29+ // zgemm_(¬rans, ¬rans, &n2, &n1, &n2,
30+ // &alpha, T2.ptr(), &n2, A, &n2, &beta, AT2.ptr(), &n2);
31+ // zgemm_(¬rans, &dagger, &n2, &n1, &n1,
32+ // &alpha, AT2.ptr(), &n2, T1.ptr(), &n1, &beta, TAT, &n2);
33+ }
34+
35+ template <typename T>
36+ inline void T1_DR_T2 (T* TAT, const T* A, const Tensor<T>& T1, const Tensor<T>& T2)
37+ {
38+ // D' = T1^T * D * T2^* = T1^T * [T2^\dagger * D^T]^T
39+ const int & n2 = T2.shape [0 ], & n1 = T1.shape [0 ];
40+ const RI::Shape_Vector& shape = { static_cast <size_t >(n1),static_cast <size_t >(n2) };
41+ RI::Tensor<T> AT2 (shape);
42+ BlasConnector::gemm (' C' , ' T' , n2, n1, n2,
43+ T (1 ), T2.ptr (), n2, A, n2, T (0 ), AT2.ptr (), n1);
44+ BlasConnector::gemm (' T' , ' T' , n1, n2, n1,
45+ T (1 ), T1.ptr (), n1, AT2.ptr (), n1, T (0 ), TAT, n2);
46+ // col-major version
47+ // zgemm_(&transpose, &dagger, &nw1, &nw2, &nw2,
48+ // &alpha, A, &nw2, T2.ptr(), &nw2, &beta, AT2.ptr(), &nw1);
49+ // zgemm_(&transpose, &transpose, &nw2, &nw1, &nw1,
50+ // &alpha, AT2.ptr(), &nw1, T1.ptr(), &nw1, &beta, TAT, &nw2);
51+ }
52+
53+ }
54+
55+ }
0 commit comments