Skip to content

Commit 8cdc4b9

Browse files
authored
[Doc] Add im2col + gemm 实现 卷积算子 (#32)
1 parent d07e7b7 commit 8cdc4b9

File tree

9 files changed

+1141
-3
lines changed

9 files changed

+1141
-3
lines changed

docs/12_convolution/01_naive_conv/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ int w; // 数据宽
5454
int k; // 卷积核数量
5555
int r; // 卷积核高
5656
int s; // 卷积核宽
57-
int u; // 卷积在高方向上的步长
57+
int u; // 卷积在高方向上的步长
5858
int v; // 卷积在宽方向上的步长
5959
int p; // 卷积在高方向上的补边
6060
int q; // 卷积在宽方向上的补边

docs/12_convolution/02_intro_conv_optimize/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
上一篇文章中,我们介绍了卷积算子的简易实现,它是直接模拟卷积操作的过程,这种实现方式的缺点是计算量大,效率低。在本文中,我们将介绍卷积算子的优化思路。
44

5-
卷积算子的主要优化思路就是将卷积运算转换为矩阵乘法运算。进而卷积算子优化问题就转化为了矩阵乘法优化问题。这篇文章中我们主要介绍一下如何将卷积运算转换为矩阵乘法运算。
5+
卷积算子的主要优化思路就是将卷积运算转换为矩阵乘法运算。进而卷积算子优化问题就转化为了矩阵乘法优化问题卷积算子的主要优化思路就是将卷积运算转换为矩阵乘法运算。进而卷积算子优化问题就转化为了矩阵乘法优化问题。这篇文章中我们主要介绍一下如何将卷积运算转换为矩阵乘法运算。
66

77
## 1. 卷积算法映射为矩阵乘法
88

docs/12_convolution/03_im2col_conv/README.md

Lines changed: 394 additions & 1 deletion
Large diffs are not rendered by default.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
CC=nvcc
2+
3+
CXXFLAGS += -DNDEBUG -DUSE_DEFAULT_STDLIB -g
4+
5+
INCLUDES += -I./include
6+
7+
LDFLAGS = -gencode arch=compute_75,code=sm_75 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61, -gencode arch=compute_70,code=sm_70
8+
9+
# 获取当前目录下的cu文件集,放在变量CUR_SOURCE中
10+
CUR_SOURCE=${wildcard ./src/*.cu}
11+
12+
# 将对应的cu文件名转为o文件后放在下面的CUR_OBJS变量中
13+
CUR_OBJS=${patsubst %.cu, %.o, $(CUR_SOURCE)}
14+
15+
EXECUTABLE=conv2ddemo
16+
17+
all: $(EXECUTABLE)
18+
19+
$(EXECUTABLE): $(CUR_OBJS)
20+
$(CC) $(CUR_OBJS) $(LDFLAGS) -o $(EXECUTABLE)
21+
22+
%.o: %.cu
23+
$(CC) -c $< $(CXXFLAGS) $(INCLUDES) -o $@ -Xptxas -v -lineinfo --std=c++11 ${LDFLAGS}
24+
25+
clean:
26+
rm -f $(EXECUTABLE)
27+
rm -f ./src/*.o
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#ifndef __CONV2D_FWD_HEADER__
2+
#define __CONV2D_FWD_HEADER__
3+
4+
#define __in__
5+
#define __out__
6+
#define __in_out__
7+
8+
typedef struct
9+
{
10+
float *in; // 输入数据地址
11+
float *weight; // 权值数据地址
12+
float *out; // 输出数据地址
13+
unsigned int n; // batch szie default value 1
14+
unsigned int c; // channel number default value 32
15+
unsigned int h; // 数据高 default value 32
16+
unsigned int w; // 数据宽 default value 32
17+
unsigned int k; // 卷积核数量 default value 32
18+
unsigned int r; // 卷积核高 default value 1
19+
unsigned int s; // 卷积核宽 default value 1
20+
unsigned int u; // 卷积在高方向上的步长 default value 1
21+
unsigned int v; // 卷积在宽方向上的步长 default value 1
22+
unsigned int p; // 卷积在高方向上的补边 default value 0
23+
unsigned int q; // 卷积在宽方向上的补边 default value 0
24+
} problem_t;
25+
26+
typedef struct
27+
{
28+
unsigned int blockx; // blockx number
29+
unsigned int blocky; // blocky number
30+
unsigned int blockz; // blockz number
31+
unsigned int threadx; // threadx number per block
32+
unsigned int thready; // thready number per block
33+
unsigned int threadz; // threadz number per block
34+
unsigned int dynmicLdsSize; // 动态分配的lds大小,如果不使用动态分配的lds,则该值为0;
35+
void *kernelPtr; // kernel ptr
36+
} kernelInfo_t;
37+
38+
int getParamsize(__in__ problem_t *problem, __out__ int *paramSize);
39+
int getkernelInfo(__in__ problem_t *problem, __out__ kernelInfo_t *kernelInfo, __in_out__ void *param);
40+
41+
#endif
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#ifndef __VERFIY_HEADER__
2+
#define __VERFIY_HEADER__
3+
4+
float getPrecision(float tmp)
5+
{
6+
int tmpInt = (int)tmp;
7+
float eNum = 1.0e-6;
8+
if (abs(tmpInt) > 0)
9+
{
10+
while (tmpInt != 0)
11+
{
12+
tmpInt = (int)(tmpInt / 10);
13+
eNum *= 10;
14+
}
15+
}
16+
else
17+
{
18+
19+
if (tmp == 0)
20+
return eNum;
21+
22+
eNum = 1.0e-5;
23+
24+
while (tmpInt == 0)
25+
{
26+
tmp *= 10;
27+
tmpInt = (int)(tmp);
28+
eNum /= 10;
29+
}
30+
}
31+
return eNum;
32+
}
33+
34+
void conv2dcpu(float *pin, float *pwei, float *pout, int n, int c, int h, int w, int k, int r, int s, int u, int v, int p, int q)
35+
{
36+
int oh = (h + 2 * p - r) / u + 1;
37+
int ow = (w + 2 * q - s) / v + 1;
38+
39+
for (int nNum = 0; nNum < n; nNum++)
40+
{
41+
for (int kNum = 0; kNum < k; kNum++)
42+
{
43+
for (int i = 0; i < oh; i++)
44+
{
45+
for (int j = 0; j < ow; j++)
46+
{
47+
double sum = 0.0;
48+
int posh = i * u - p;
49+
int posw = j * v - q;
50+
51+
for (int cNum = 0; cNum < c; cNum++)
52+
{
53+
for (int khNum = 0; khNum < r; khNum++)
54+
{
55+
for (int kwNum = 0; kwNum < s; kwNum++)
56+
{
57+
int posh_ori = posh + khNum;
58+
int posw_ori = posw + kwNum;
59+
if (posw_ori >= 0 && posh_ori >= 0 && posw_ori < w && posh_ori < h)
60+
{
61+
sum += (double)(pin[nNum * c * h * w + cNum * (w * h) + posh_ori * w + posw_ori] * pwei[kNum * r * s * c + cNum * r * s + khNum * s + kwNum]);
62+
}
63+
}
64+
}
65+
}
66+
67+
pout[nNum * k * oh * ow + kNum * oh * ow + i * ow + j] = (float)sum;
68+
}
69+
}
70+
}
71+
}
72+
}
73+
#endif
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/bin/bash
2+
make clean
3+
make
4+
5+
./conv2ddemo 128 3 225 225 32 3 3 2 2 0 0
6+
./conv2ddemo 49 128 35 35 384 3 3 2 2 0 0
7+
./conv2ddemo 16 128 105 105 256 3 3 2 2 0 0
8+
./conv2ddemo 128 3 230 230 64 7 7 2 2 0 0
9+
./conv2ddemo 2 3 838 1350 64 7 7 2 2 0 0
10+
./conv2ddemo 256 256 28 28 256 2 2 2 2 0 0
11+
./conv2ddemo 128 3 225 225 32 3 3 1 1 0 0

0 commit comments

Comments
 (0)