Skip to content

Commit 8b2481f

Browse files
committed
1. update Tensor_Multiply
2. fix Tensor::operator()
1 parent bd299fa commit 8b2481f

12 files changed

+552
-69
lines changed

include/RI/global/Tensor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ T& Tensor<T>::operator() (const std::size_t i0, const std::size_t i1, const std:
9090
template<typename T>
9191
T& Tensor<T>::operator() (const std::size_t i0, const std::size_t i1, const std::size_t i2, const std::size_t i3) const
9292
{
93-
assert(this->shape.size()==3);
93+
assert(this->shape.size()==4);
9494
assert(i0>=0); assert(i0<this->shape[0]);
9595
assert(i1>=0); assert(i1<this->shape[1]);
9696
assert(i2>=0); assert(i2<this->shape[2]);
include/RI/global/Tensor_Multiply-22.hpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// ===================
2+
// Author: Peize Lin
3+
// date: 2023.08.03
4+
// ===================
5+
6+
#pragma once
7+
8+
#include "Tensor_Multiply.h"
9+
10+
namespace RI
11+
{
12+
13+
namespace Tensor_Multiply
14+
{
15+
// Txy(x0,y0) = Tx(x0,a) * Ty(y0,a)
16+
template<typename Tdata>
17+
Tensor<Tdata> x0y0_x0a_y0a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18+
{
19+
assert(Tx.shape.size()==2);
20+
assert(Ty.shape.size()==2);
21+
Tensor<Tdata> Txy({Tx.shape[0], Ty.shape[0]});
22+
Blas_Interface::gemm(
23+
'N', 'T',
24+
Tx.shape[0],
25+
Ty.shape[0],
26+
Tx.shape[1],
27+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
28+
Tdata(0.0), Txy.ptr());
29+
return Txy;
30+
}
31+
32+
// Txy(x0,y1) = Tx(x0,a) * Ty(a,y1)
33+
template<typename Tdata>
34+
Tensor<Tdata> x0y1_x0a_ay1(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
35+
{
36+
assert(Tx.shape.size()==2);
37+
assert(Ty.shape.size()==2);
38+
Tensor<Tdata> Txy({Tx.shape[0], Ty.shape[1]});
39+
Blas_Interface::gemm(
40+
'N', 'N',
41+
Tx.shape[0],
42+
Ty.shape[1],
43+
Tx.shape[1],
44+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
45+
Tdata(0.0), Txy.ptr());
46+
return Txy;
47+
}
48+
49+
// Txy(x1,y0) = Tx(a,x1) * Ty(y0,a)
50+
template<typename Tdata>
51+
Tensor<Tdata> x1y0_ax1_y0a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
52+
{
53+
assert(Tx.shape.size()==2);
54+
assert(Ty.shape.size()==2);
55+
Tensor<Tdata> Txy({Tx.shape[1], Ty.shape[0]});
56+
Blas_Interface::gemm(
57+
'T', 'T',
58+
Tx.shape[1],
59+
Ty.shape[0],
60+
Tx.shape[0],
61+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
62+
Tdata(0.0), Txy.ptr());
63+
return Txy;
64+
}
65+
66+
// Txy(x1,y1) = Tx(a,x1) * Ty(a,y1)
67+
template<typename Tdata>
68+
Tensor<Tdata> x1y1_ax1_ay1(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
69+
{
70+
assert(Tx.shape.size()==2);
71+
assert(Ty.shape.size()==2);
72+
Tensor<Tdata> Txy({Tx.shape[1], Ty.shape[1]});
73+
Blas_Interface::gemm(
74+
'T', 'N',
75+
Tx.shape[1],
76+
Ty.shape[1],
77+
Tx.shape[0],
78+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
79+
Tdata(0.0), Txy.ptr());
80+
return Txy;
81+
}
82+
83+
}
84+
85+
}

include/RI/global/Tensor_Multiply-23.hpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,40 @@ namespace RI
1212

1313
namespace Tensor_Multiply
1414
{
15+
// Txy(x0,y0,y1) = Tx(x0,a) * Ty(y0,y1,a)
16+
template<typename Tdata>
17+
Tensor<Tdata> x0y0y1_x0a_y0y1a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18+
{
19+
assert(Tx.shape.size()==2);
20+
assert(Ty.shape.size()==3);
21+
Tensor<Tdata> Txy({Tx.shape[0], Ty.shape[0], Ty.shape[1]});
22+
Blas_Interface::gemm(
23+
'N', 'T',
24+
Tx.shape[0],
25+
Ty.shape[0] * Ty.shape[1],
26+
Tx.shape[1],
27+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
28+
Tdata(0.0), Txy.ptr());
29+
return Txy;
30+
}
31+
32+
// Txy(x0,y1,y2) = Tx(x0,a) * Ty(a,y1,y2)
33+
template<typename Tdata>
34+
Tensor<Tdata> x0y1y2_x0a_ay1y2(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
35+
{
36+
assert(Tx.shape.size()==2);
37+
assert(Ty.shape.size()==3);
38+
Tensor<Tdata> Txy({Tx.shape[0], Ty.shape[1], Ty.shape[2]});
39+
Blas_Interface::gemm(
40+
'N', 'N',
41+
Tx.shape[0],
42+
Ty.shape[1] * Ty.shape[2],
43+
Tx.shape[1],
44+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
45+
Tdata(0.0), Txy.ptr());
46+
return Txy;
47+
}
48+
1549
// Txy(x1,y0,y1) = Tx(a,x1) * Ty(y0,y1,a)
1650
template<typename Tdata>
1751
Tensor<Tdata> x1y0y1_ax1_y0y1a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
@@ -29,18 +63,18 @@ namespace Tensor_Multiply
2963
return Txy;
3064
}
3165

32-
// Txy(x0,y0,y1) = Tx(x0,a) * Ty(y0,y1,a)
66+
// Txy(x1,y1,y2) = Tx(a,x1) * Ty(a,y1,y2)
3367
template<typename Tdata>
34-
Tensor<Tdata> x0y0y1_x0a_y0y1a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
68+
Tensor<Tdata> x1y1y2_ax1_ay1y2(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
3569
{
3670
assert(Tx.shape.size()==2);
3771
assert(Ty.shape.size()==3);
38-
Tensor<Tdata> Txy({Tx.shape[0], Ty.shape[0], Ty.shape[1]});
72+
Tensor<Tdata> Txy({Tx.shape[1], Ty.shape[1], Ty.shape[2]});
3973
Blas_Interface::gemm(
40-
'N', 'T',
41-
Tx.shape[0],
42-
Ty.shape[0] * Ty.shape[1],
74+
'T', 'N',
4375
Tx.shape[1],
76+
Ty.shape[1] * Ty.shape[2],
77+
Tx.shape[0],
4478
Tdata(1.0), Tx.ptr(), Ty.ptr(),
4579
Tdata(0.0), Txy.ptr());
4680
return Txy;

include/RI/global/Tensor_Multiply-32.hpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,34 @@ namespace RI
1212

1313
namespace Tensor_Multiply
1414
{
15-
// Txy(x1,x2,y1) = Tx(a,x1,x2) * Ty(a,y1)
15+
// Txy(x0,x1,y0) = Tx(x0,x1,a) * Ty(y0,a)
1616
template<typename Tdata>
17-
Tensor<Tdata> x1x2y1_ax1x2_ay1(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
17+
Tensor<Tdata> x0x1y0_x0x1a_y0a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
1818
{
1919
assert(Tx.shape.size()==3);
2020
assert(Ty.shape.size()==2);
21-
Tensor<Tdata> Txy({Tx.shape[1], Tx.shape[2], Ty.shape[1]});
21+
Tensor<Tdata> Txy({Tx.shape[0], Tx.shape[1], Ty.shape[0]});
2222
Blas_Interface::gemm(
23-
'T', 'N',
24-
Tx.shape[1] * Tx.shape[2],
25-
Ty.shape[1],
26-
Tx.shape[0],
23+
'N', 'T',
24+
Tx.shape[0] * Tx.shape[1],
25+
Ty.shape[0],
26+
Tx.shape[2],
2727
Tdata(1.0), Tx.ptr(), Ty.ptr(),
2828
Tdata(0.0), Txy.ptr());
2929
return Txy;
3030
}
3131

32-
// Txy(x0,x1,y0) = Tx(x0,x1,a) * Ty(y0,a)
32+
// Txy(x0,x1,y1) = Tx(x0,x1,a) * Ty(a,y1)
3333
template<typename Tdata>
34-
Tensor<Tdata> x0x1y0_x0x1a_y0a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
34+
Tensor<Tdata> x0x1y1_x0x1a_ay1(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
3535
{
3636
assert(Tx.shape.size()==3);
3737
assert(Ty.shape.size()==2);
38-
Tensor<Tdata> Txy({Tx.shape[0], Tx.shape[1], Ty.shape[0]});
38+
Tensor<Tdata> Txy({Tx.shape[0], Tx.shape[1], Ty.shape[1]});
3939
Blas_Interface::gemm(
40-
'N', 'T',
40+
'N', 'N',
4141
Tx.shape[0] * Tx.shape[1],
42-
Ty.shape[0],
42+
Ty.shape[1],
4343
Tx.shape[2],
4444
Tdata(1.0), Tx.ptr(), Ty.ptr(),
4545
Tdata(0.0), Txy.ptr());
@@ -63,18 +63,18 @@ namespace Tensor_Multiply
6363
return Txy;
6464
}
6565

66-
// Txy(x0,x1,y1) = Tx(x0,x1,a) * Ty(a,y1)
66+
// Txy(x1,x2,y1) = Tx(a,x1,x2) * Ty(a,y1)
6767
template<typename Tdata>
68-
Tensor<Tdata> x0x1y1_x0x1a_ay1(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
68+
Tensor<Tdata> x1x2y1_ax1x2_ay1(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
6969
{
7070
assert(Tx.shape.size()==3);
7171
assert(Ty.shape.size()==2);
72-
Tensor<Tdata> Txy({Tx.shape[0], Tx.shape[1], Ty.shape[1]});
72+
Tensor<Tdata> Txy({Tx.shape[1], Tx.shape[2], Ty.shape[1]});
7373
Blas_Interface::gemm(
74-
'N', 'N',
75-
Tx.shape[0] * Tx.shape[1],
74+
'T', 'N',
75+
Tx.shape[1] * Tx.shape[2],
7676
Ty.shape[1],
77-
Tx.shape[2],
77+
Tx.shape[0],
7878
Tdata(1.0), Tx.ptr(), Ty.ptr(),
7979
Tdata(0.0), Txy.ptr());
8080
return Txy;

include/RI/global/Tensor_Multiply-33.hpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,23 @@ namespace RI
1212

1313
namespace Tensor_Multiply
1414
{
15+
// Txy(x0,y0) = Tx(x0,a,b) * Ty(y0,a,b)
16+
template<typename Tdata>
17+
Tensor<Tdata> x0y0_x0ab_y0ab(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
18+
{
19+
assert(Tx.shape.size()==3);
20+
assert(Ty.shape.size()==3);
21+
Tensor<Tdata> Txy({Tx.shape[0], Ty.shape[0]});
22+
Blas_Interface::gemm(
23+
'N', 'T',
24+
Tx.shape[0],
25+
Ty.shape[0],
26+
Tx.shape[1] * Tx.shape[2],
27+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
28+
Tdata(0.0), Txy.ptr());
29+
return Txy;
30+
}
31+
1532
// Txy(x0,y2) = Tx(x0,a,b) * Ty(a,b,y2)
1633
template<typename Tdata>
1734
Tensor<Tdata> x0y2_x0ab_aby2(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
@@ -63,6 +80,74 @@ namespace Tensor_Multiply
6380
return Txy;
6481
}
6582

83+
// Txy(x0,x1,y0,y1) = Tx(x0,x1,a) * Ty(y0,y1,a)
84+
template<typename Tdata>
85+
Tensor<Tdata> x0x1y0y1_x0x1a_y0y1a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
86+
{
87+
assert(Tx.shape.size()==3);
88+
assert(Ty.shape.size()==3);
89+
Tensor<Tdata> Txy({Tx.shape[0], Tx.shape[1], Ty.shape[0], Ty.shape[1]});
90+
Blas_Interface::gemm(
91+
'N', 'T',
92+
Tx.shape[0] * Tx.shape[1],
93+
Ty.shape[0] * Ty.shape[1],
94+
Tx.shape[2],
95+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
96+
Tdata(0.0), Txy.ptr());
97+
return Txy;
98+
}
99+
100+
// Txy(x0,x1,y1,y2) = Tx(x0,x1,a) * Ty(a,y1,y2)
101+
template<typename Tdata>
102+
Tensor<Tdata> x0x1y1y2_x0x1a_ay1y2(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
103+
{
104+
assert(Tx.shape.size()==3);
105+
assert(Ty.shape.size()==3);
106+
Tensor<Tdata> Txy({Tx.shape[0], Tx.shape[1], Ty.shape[1], Ty.shape[2]});
107+
Blas_Interface::gemm(
108+
'N', 'N',
109+
Tx.shape[0] * Tx.shape[1],
110+
Ty.shape[1] * Ty.shape[2],
111+
Tx.shape[2],
112+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
113+
Tdata(0.0), Txy.ptr());
114+
return Txy;
115+
}
116+
117+
// Txy(x1,x2,y0,y1) = Tx(a,x1,x2) * Ty(y0,y1,a)
118+
template<typename Tdata>
119+
Tensor<Tdata> x1x2y0y1_ax1x2_y0y1a(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
120+
{
121+
assert(Tx.shape.size()==3);
122+
assert(Ty.shape.size()==3);
123+
Tensor<Tdata> Txy({Tx.shape[1], Tx.shape[2], Ty.shape[0], Ty.shape[1]});
124+
Blas_Interface::gemm(
125+
'T', 'T',
126+
Tx.shape[1] * Tx.shape[2],
127+
Ty.shape[0] * Ty.shape[1],
128+
Tx.shape[0],
129+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
130+
Tdata(0.0), Txy.ptr());
131+
return Txy;
132+
}
133+
134+
// Txy(x1,x2,y1,y2) = Tx(a,x1,x2) * Ty(a,y1,y2)
135+
template<typename Tdata>
136+
Tensor<Tdata> x1x2y1y2_ax1x2_ay1y2(const Tensor<Tdata> &Tx, const Tensor<Tdata> &Ty)
137+
{
138+
assert(Tx.shape.size()==3);
139+
assert(Ty.shape.size()==3);
140+
Tensor<Tdata> Txy({Tx.shape[1], Tx.shape[2], Ty.shape[1], Ty.shape[2]});
141+
Blas_Interface::gemm(
142+
'T', 'N',
143+
Tx.shape[1] * Tx.shape[2],
144+
Ty.shape[1] * Ty.shape[2],
145+
Tx.shape[0],
146+
Tdata(1.0), Tx.ptr(), Ty.ptr(),
147+
Tdata(0.0), Txy.ptr());
148+
return Txy;
149+
}
150+
66151
}
67152

68153
}

include/RI/global/Tensor_Multiply.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "Tensor.h"
99
#include "Blas_Interface-Contiguous.h"
1010

11+
#include "Tensor_Multiply-22.hpp"
1112
#include "Tensor_Multiply-23.hpp"
1213
#include "Tensor_Multiply-32.hpp"
1314
#include "Tensor_Multiply-33.hpp"

unittests/global/Tensor-test.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,27 @@ std::ostream &operator<<(std::ostream &os, const Tensor<T> &t)
6161
os<<"]"<<std::endl;
6262
return os;
6363
}
64+
case 4:
65+
{
66+
os<<"{"<<std::endl;
67+
for(std::size_t i0=0; i0<t.shape[0]; ++i0)
68+
{
69+
os<<"["<<std::endl;
70+
for(std::size_t i1=0; i1<t.shape[1]; ++i1)
71+
{
72+
for(std::size_t i2=0; i2<t.shape[2]; ++i2)
73+
{
74+
for(std::size_t i3=0; i3<t.shape[3]; ++i3)
75+
os<<t(i0,i1,i2,i3)<<"\t";
76+
os<<std::endl;
77+
}
78+
os<<std::endl;
79+
}
80+
os<<"]"<<std::endl;
81+
}
82+
os<<"}"<<std::endl;
83+
return os;
84+
}
6485
default:
6586
throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__));
6687
}

0 commit comments

Comments
 (0)