Skip to content

Commit 442acff

Browse files
Adding the seq-2-seq FastGRNN verification for a one brick input
1 parent 5fdaba6 commit 442acff

File tree

7 files changed

+110
-5
lines changed

7 files changed

+110
-5
lines changed

c_reference/tests/Makefile

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@ INCLUDE_DIR=../include
77
SRC_DIR=../src
88
IFLAGS = -I $(INCLUDE_DIR)
99

10-
all: test_kws test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv
10+
all: test_rnn test_postcnn test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv
1111

1212
KWS_DIR=kws
13-
test_kws: $(KWS_DIR)/test_kws.c $(SRC_DIR)/conv_utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o
13+
test_postcnn: $(KWS_DIR)/test_postcnn.c $(SRC_DIR)/conv_utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o
14+
$(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm
15+
test_rnn: $(KWS_DIR)/test_rnn.c $(SRC_DIR)/fastgrnn.o $(SRC_DIR)/utils.o
1416
$(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm
1517

1618
DSCNN_DIR=dscnn
@@ -54,7 +56,7 @@ test_quantized_mbconv: $(MBCONV_DIR)/test_quantized_mbconv.c $(SRC_DIR)/quantize
5456
.PHONY: clean cleanest
5557

5658
clean:
57-
rm -f *.o *.gch test_kws test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv
59+
rm -f *.o *.gch test_rnn test_postcnn test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv
5860

5961
cleanest: clean
6062
rm *~
File renamed without changes.
File renamed without changes.

c_reference/tests/kws/rnn_io.h

Lines changed: 9 additions & 0 deletions
Large diffs are not rendered by default.

c_reference/tests/kws/rnn_params.h

Lines changed: 20 additions & 0 deletions
Large diffs are not rendered by default.

c_reference/tests/kws/test_kws.c renamed to c_reference/tests/kws/test_postcnn.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include<stdio.h>
22
#include<stdlib.h>
33

4-
#include"kws_params.h"
5-
#include"kws_io.h"
4+
#include"postcnn_params.h"
5+
#include"postcnn_io.h"
66
#include"conv1d.h"
77
#include"dscnn.h"
88
#include"conv_utils.h"

c_reference/tests/kws/test_rnn.c

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#include<stdio.h>
2+
#include<stdlib.h>
3+
#include<string.h>
4+
#include"rnn_io.h"
5+
#include"rnn_params.h"
6+
#include"fastgrnn.h"
7+
#include"utils.h"
8+
// #include"dscnn.h"
9+
// #include"conv_utils.h"
10+
11+
int main(){
12+
13+
FastGRNN_LR_Params RNN_params = {
14+
.mean = 0,
15+
.stdDev = 0,
16+
.W1 = F_W1,
17+
.W2 = F_W2,
18+
.wRank = LOW_RANK,
19+
.U1 = F_U1,
20+
.U2 = F_U2,
21+
.uRank = LOW_RANK,
22+
.Bg = F_BIAS_GATE,
23+
.Bh = F_BIAS_UPDATE,
24+
.sigmoid_zeta = sigmoid(F_ZETA),
25+
.sigmoid_nu = sigmoid(F_NU)
26+
};
27+
28+
float preComp[O_F] = { 0.0 };
29+
float tempLRW[LOW_RANK] = { 0.0 };
30+
float tempLRU[LOW_RANK] = { 0.0 };
31+
float normFeatures[I_F] = { 0.0 };
32+
FastGRNN_LR_Buffers buffers = {
33+
.preComp = preComp,
34+
.tempLRW = tempLRW,
35+
.tempLRU = tempLRU,
36+
.normFeatures = normFeatures
37+
};
38+
39+
float pred[O_T * O_F] = {0.0};
40+
// float pred[O_F] = {0.0};
41+
// int fastgrnn_lr(float* const hiddenState, unsigned hiddenDims,
42+
// const float* const input, unsigned inputDims, unsigned steps,
43+
// const void* params, void* buffers, int backward, int normalize);
44+
45+
float* temp_hiddenstate = (float*)malloc(O_F*sizeof(float));
46+
for(int t = 0 ; t < I_T ; t++){
47+
fastgrnn_lr(temp_hiddenstate, O_F,
48+
INPUT + (t * I_F) , I_F, 1,
49+
&RNN_params, &buffers, 0, 0);
50+
memcpy(pred + (t * O_F), temp_hiddenstate, O_F*sizeof(float));
51+
}
52+
53+
// fastgrnn_lr(pred, O_F,
54+
// INPUT, I_F, I_T,
55+
// &RNN_params, &buffers, 0, 0);
56+
57+
// Calculate Error(Aggregate Squared and Mean Squared)
58+
float error = 0, denom = 0;
59+
for(int t = 0 ; t < O_T ; t++){
60+
for(int d = 0 ; d < O_F ; d++){
61+
error += ((pred[t * O_F + d] - OUTPUT[t * O_F + d]) * (pred[t * O_F + d] - OUTPUT[t * O_F + d]));
62+
// printf("%f %f\t", pred[t * O_F + d], OUTPUT[t * O_F + d]);
63+
// error += ((pred[d] - OUTPUT[t * O_F + d]) * (pred[d] - OUTPUT[t * O_F + d]));
64+
// printf("%f %f\n", pred[d], OUTPUT[t * O_F + d]);
65+
denom += OUTPUT[t * O_F + d] * OUTPUT[t * O_F + d];
66+
}
67+
}
68+
float avg_error = error/(O_T*O_F);
69+
printf("RNN Block\n");
70+
printf("Aggregate Squared Error : %f ; Mean Sqaured Error : %f \n", error, avg_error);
71+
printf("RMS : %f \n", error/denom);
72+
73+
return 0 ;
74+
}

0 commit comments

Comments
 (0)