Skip to content

Commit 392cf2e

Browse files
author
Anirudh B H
committed
Test for vanilla Conv1d layer
1 parent 0878200 commit 392cf2e

File tree

4 files changed

+58
-3
lines changed

4 files changed

+58
-3
lines changed

c_reference/include/conv_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,7 @@
55
#include <float.h>
66

77
int prepareLowRankConvMat(float* out, float* W1, float* W2, unsigned rank, unsigned I, unsigned J);
8+
float sigmoid(float x);
9+
float relu(float x);
810

911
#endif

c_reference/src/conv_utils.c

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,22 @@
33
#include <float.h>
44

55
int prepareLowRankConvMat(float* out, float* W1, float* W2, unsigned rank, unsigned I, unsigned J){
6-
for(i = 0 ; i < I, i++){
7-
for(j = 0 ; j < J, j++){
6+
for(int i = 0 ; i < I; i++){
7+
for(int j = 0 ; j < J; j++){
88
float sum = 0;
9-
for(k = 0; k < rank ; k++){
9+
for(int k = 0; k < rank ; k++){
1010
sum += (W1[i * rank + k] * W2[k * J + j]);
1111
}
1212
out[i * J + j] = sum;
1313
}
1414
}
15+
}
16+
17+
float relu(float x) {
18+
if (x < 0.0) return 0.0;
19+
else return x;
20+
}
21+
22+
float sigmoid(float x) {
23+
return 1.0f / (1.0f + expf(-1.0f * x));
1524
}

c_reference/tests/conv1d/conv_param.h

Lines changed: 17 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include<stdio.h>
2+
#include<stdlib.h>
3+
4+
#include"conv_param.h"
5+
#include"conv1d.h"
6+
#include"conv_utils.h"
7+
8+
int main(){
9+
10+
ConvLayers_Params conv_params = {
11+
.W = CONV_WEIGHT,
12+
.B = CONV_BIAS,
13+
};
14+
15+
float pred[O_T * O_F] = {};
16+
Conv1D(pred, O_T, O_F, INPUT, 1, I_T, I_F, PAD, FILT, &conv_params, ACT);
17+
float error = 0;
18+
for(int t = 0 ; t < O_T ; t++){
19+
for(int d = 0 ; d < O_F ; d++){
20+
error += ((pred[t * O_F + d] - OUTPUT[t * O_F + d]) * (pred[t * O_F + d] - OUTPUT[t * O_F + d]));
21+
}
22+
}
23+
float avg_error = error/(O_T*O_F);
24+
printf("%f \t %f \n", error, avg_error);
25+
26+
return 0 ;
27+
}

0 commit comments

Comments
 (0)