Skip to content

Commit a1d0733

Browse files
committed
1. add all Label::ab_ab for LRI::cal_loop3()
1 parent 6208db9 commit a1d0733

File tree

7 files changed

+894
-68
lines changed

7 files changed

+894
-68
lines changed

include/RI/global/Tensor_Multiply-23.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,23 @@ namespace Tensor_Multiply
2929
return Txy;
3030
}
3131

32+
// Txy(x0,y0,y1) = Tx(x0,a) * Ty(y0,y1,a)
33+
template<typename Tdata>
34+
Tensor<Tdata> x0y0y1_x0a_y0y1a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
35+
{
36+
assert(Tx.shape.size()==2);
37+
assert(Ty.shape.size()==3);
38+
Tensor<Tdata> Txy({Tx.shape[0], Ty.shape[0], Ty.shape[1]});
39+
Blas_Interface::gemm(
40+
'N', 'T',
41+
Tx.shape[0],
42+
Ty.shape[0] * Ty.shape[1],
43+
Tx.shape[1],
44+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
45+
Tdata(0.0), Txy.ptr());
46+
return Txy;
47+
}
48+
3249
}
3350

3451
}

include/RI/global/Tensor_Multiply-33.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,23 @@ namespace Tensor_Multiply
4646
return Txy;
4747
}
4848

49+
// Txy(x2,y2) = Tx(a,b,x2) * Ty(a,b,y2)
50+
template<typename Tdata>
51+
Tensor<Tdata> x2y2_abx2_aby2(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
52+
{
53+
assert(Tx.shape.size()==3);
54+
assert(Ty.shape.size()==3);
55+
Tensor<Tdata> Txy({Tx.shape[2], Ty.shape[2]});
56+
Blas_Interface::gemm(
57+
'T', 'N',
58+
Tx.shape[2],
59+
Ty.shape[2],
60+
Tx.shape[0] * Tx.shape[1],
61+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
62+
Tdata(0.0), Txy.ptr());
63+
return Txy;
64+
}
65+
4966
}
5067

5168
}

0 commit comments

Comments
 (0)