Skip to content

Commit c5ff3ab

Browse files
committed
1. add class Tensor_Multiply
2. add LRI::cal_loop3()
1 parent 7992040 commit c5ff3ab

File tree

14 files changed

+1093
-21
lines changed

14 files changed

+1093
-21
lines changed

include/RI/global/Tensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class Tensor
5454

5555
bool empty() const { return shape.empty(); }
5656

57-
//Tensor & operator += (const Tensor &);
57+
Tensor & operator += (const Tensor &);
5858
Tensor operator-() const;
5959

6060
template <class Archive> void serialize( Archive & ar ){ ar(shape, data); } // for cereal

include/RI/global/Tensor.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,15 +151,13 @@ Tensor<T> operator* (const Tensor<T> &t1, const T &t2)
151151
return t;
152152
}
153153

154-
/*
155154
template<typename T>
156155
Tensor<T> &Tensor<T>::operator+= (const Tensor &t)
157156
{
158157
assert(same_shape(*this,t));
159158
*this->data += *t.data;
160159
return *this;
161160
}
162-
*/
163161

164162
template<typename T>
165163
Tensor<T> Tensor<T>::transpose() const
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// ===================
2+
// Author: Peize Lin
3+
// date: 2023.08.03
4+
// ===================
5+
6+
#pragma once
7+
8+
#include "Tensor_Multiply.h"
9+
10+
namespace RI
11+
{
12+
13+
namespace Tensor_Multiply
14+
{
15+
// Txy(x1,y0,y1) = Tx(a,x1) * Ty(y0,y1,a)
16+
template<typename Tdata>
17+
Tensor<Tdata> x1y0y1_x0y2(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18+
{
19+
assert(Tx.shape.size()==2);
20+
assert(Ty.shape.size()==3);
21+
const std::size_t y01 = Ty.shape[0] * Ty.shape[1];
22+
Tensor<Tdata> Txy({Tx.shape[1], Ty.shape[0], Ty.shape[1]});
23+
Blas_Interface::gemm(
24+
'T', 'T', Tx.shape[1], y01, Tx.shape[0],
25+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
26+
Tdata(0.0), Txy.ptr());
27+
return Txy;
28+
}
29+
}
30+
31+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// ===================
2+
// Author: Peize Lin
3+
// date: 2023.08.03
4+
// ===================
5+
6+
#pragma once
7+
8+
#include "Tensor_Multiply.h"
9+
10+
namespace RI
11+
{
12+
13+
namespace Tensor_Multiply
14+
{
15+
// Txy(x1,x2,y1) = Tx(a,x1,x2) * Ty(a,y1)
16+
template<typename Tdata>
17+
Tensor<Tdata> x1x2y1_x0y0(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18+
{
19+
assert(Tx.shape.size()==3);
20+
assert(Ty.shape.size()==2);
21+
assert(Tx.shape[0]==Ty.shape[0]);
22+
const std::size_t x12 = Tx.shape[1] * Tx.shape[2];
23+
Tensor<Tdata> Txy({Tx.shape[1], Tx.shape[2], Ty.shape[1]});
24+
Blas_Interface::gemm(
25+
'T', 'N', x12, Ty.shape[1], Tx.shape[0],
26+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
27+
Tdata(0.0), Txy.ptr());
28+
return Txy;
29+
}
30+
31+
// Txy(x0,x1,y0) = Tx(x0,x1,a) * Ty(y0,a)
32+
template<typename Tdata>
33+
Tensor<Tdata> x0x1y0_x2y1(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
34+
{
35+
assert(Tx.shape.size()==3);
36+
assert(Ty.shape.size()==2);
37+
assert(Tx.shape[2]==Ty.shape[1]);
38+
const std::size_t x01 = Tx.shape[0] * Tx.shape[1];
39+
Tensor<Tdata> Txy({Tx.shape[0], Tx.shape[1], Ty.shape[0]});
40+
Blas_Interface::gemm(
41+
'N', 'T', x01, Ty.shape[0], Tx.shape[2],
42+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
43+
Tdata(0.0), Txy.ptr());
44+
return Txy;
45+
}
46+
47+
// Txy(x1,x2,y0) = Tx(a,x1,x2) * Ty(y0,a)
48+
template<typename Tdata>
49+
Tensor<Tdata> x1x2y0_x0y1(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
50+
{
51+
assert(Tx.shape.size()==3);
52+
assert(Ty.shape.size()==2);
53+
const std::size_t x12 = Tx.shape[1] * Tx.shape[2];
54+
Tensor<Tdata> Txy({Tx.shape[1], Tx.shape[2], Ty.shape[0]});
55+
Blas_Interface::gemm(
56+
'T', 'T', x12, Ty.shape[0], Tx.shape[0],
57+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
58+
Tdata(0.0), Txy.ptr());
59+
return Txy;
60+
}
61+
62+
// Txy(x0,x1,y1) = Tx(x0,x1,a) * Ty(a,y1)
63+
template<typename Tdata>
64+
Tensor<Tdata> x0x1y1_x2y0(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
65+
{
66+
assert(Tx.shape.size()==3);
67+
assert(Ty.shape.size()==2);
68+
const std::size_t x01 = Tx.shape[0] * Tx.shape[1];
69+
Tensor<Tdata> Txy({Tx.shape[0], Tx.shape[1], Ty.shape[1]});
70+
Blas_Interface::gemm(
71+
'N', 'N', x01, Ty.shape[1], Tx.shape[2],
72+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
73+
Tdata(0.0), Txy.ptr());
74+
return Txy;
75+
}
76+
}
77+
78+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// ===================
2+
// Author: Peize Lin
3+
// date: 2023.08.03
4+
// ===================
5+
6+
#pragma once
7+
8+
#include "Tensor_Multiply.h"
9+
10+
namespace RI
11+
{
12+
13+
namespace Tensor_Multiply
14+
{
15+
// Txy(x0,y2) = Tx(x0,a,b) * Ty(a,b,y2)
16+
template<typename Tdata>
17+
Tensor<Tdata> x0y2_x1y0_x2y1(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18+
{
19+
assert(Tx.shape.size()==3);
20+
assert(Ty.shape.size()==3);
21+
assert(Tx.shape[1]==Ty.shape[0]);
22+
assert(Tx.shape[2]==Ty.shape[1]);
23+
const std::size_t x12 = Tx.shape[1] * Tx.shape[2];
24+
Tensor<Tdata> Txy({Tx.shape[0], Ty.shape[2]});
25+
Blas_Interface::gemm(
26+
'N', 'N', Tx.shape[0], Ty.shape[2], x12,
27+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
28+
Tdata(0.0), Txy.ptr());
29+
return Txy;
30+
}
31+
32+
// Txy(x2,y0) = Tx(a,b,x2) * Ty(y0,a,b)
33+
template<typename Tdata>
34+
Tensor<Tdata> x2y0_x0y1_x1y2(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
35+
{
36+
assert(Tx.shape.size()==3);
37+
assert(Ty.shape.size()==3);
38+
assert(Tx.shape[0]==Ty.shape[1]);
39+
assert(Tx.shape[1]==Ty.shape[2]);
40+
const std::size_t x01 = Tx.shape[0] * Tx.shape[1];
41+
Tensor<Tdata> Txy({Tx.shape[2], Ty.shape[0]});
42+
Blas_Interface::gemm(
43+
'T', 'T', Tx.shape[2], Ty.shape[0], x01,
44+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
45+
Tdata(0.0), Txy.ptr());
46+
return Txy;
47+
}
48+
}
49+
50+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// ===================
2+
// Author: Peize Lin
3+
// date: 2023.08.03
4+
// ===================
5+
6+
#pragma once
7+
8+
#include "Tensor.h"
9+
#include "Blas_Interface-Contiguous.h"
10+
11+
#include "Tensor_Multiply-23.hpp"
12+
#include "Tensor_Multiply-32.hpp"
13+
#include "Tensor_Multiply-33.hpp"

include/RI/physics/Exx.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ void Exx<TA,Tcell,Ndim,Tdata>::cal_Hs(
118118

119119
std::vector<std::map<TA, std::map<TAC, Tensor<Tdata>>>> Hs_vec(1);
120120
this->lri.coefficients = {nullptr};
121-
this->lri.cal(
121+
this->lri.cal_loop3(
122122
{Label::ab_ab::a0b0_a1b1,
123123
Label::ab_ab::a0b0_a1b2,
124124
Label::ab_ab::a0b0_a2b1,

include/RI/physics/GW.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void G0W0<TA, Tcell, Ndim, Tdata>::cal_Sigc(
5959

6060
std::vector<std::map<TA, std::map<TAC, Tensor<Tdata>>>> Sigc_vec(1);
6161
this->lri.coefficients = {nullptr};
62-
this->lri.cal(
62+
this->lri.cal_loop3(
6363
{Label::ab_ab::a0b0_a1b1,
6464
Label::ab_ab::a0b0_a1b2,
6565
Label::ab_ab::a0b0_a2b1,

0 commit comments

Comments
 (0)