Skip to content

Commit 1b660c5

Browse files
Created and tested both dscn blocks
1 parent ca28078 commit 1b660c5

File tree

7 files changed

+144
-7
lines changed

7 files changed

+144
-7
lines changed

c_reference/include/dscnn.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,10 @@ int DSCNN_LR(float* output_signal, float* input_signal, unsigned in_T, unsigned
1010
unsigned affine, float* gamma, float* beta, unsigned in_place, unsigned cnn_hidden, int cnn_padding, unsigned cnn_kernel_size,
1111
const void* cnn_params, int cnn_activations);
1212

13+
int DSCNN_LR_Point_Depth(float* output_signal, float* input_signal, unsigned in_T, unsigned in_channels, float* mean, float* var,
14+
unsigned affine, float* gamma, float* beta, unsigned in_place, unsigned depth_cnn_hidden, int depth_cnn_padding,
15+
unsigned depth_cnn_kernel_size, const void* depth_cnn_params, int depth_cnn_activations, unsigned point_cnn_hidden,
16+
int point_cnn_padding, unsigned point_cnn_kernel_size, const void* point_cnn_params, int point_cnn_activations,
17+
int pool_padding, unsigned pool_kernel_size, int pool_activation);
1318

1419
#endif

c_reference/src/dscnn.c

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
#include"dscnn.h"
22
#include"conv1d.h"
33
#include"conv_utils.h"
4+
#include<stdio.h>
45

56
int DSCNN_LR(float* output_signal, float* input_signal, unsigned in_T, unsigned in_channels, float* mean, float* var,
67
unsigned affine, float* gamma, float* beta, unsigned in_place, unsigned cnn_hidden, int cnn_padding, unsigned cnn_kernel_size,
78
const void* cnn_params, int cnn_activations){
8-
9-
float out_T;
10-
9+
unsigned out_T;
1110
// BatchNorm
1211
float* norm_out = (float*)malloc(in_T*in_channels*sizeof(float));
1312
BatchNorm1d(norm_out, input_signal, in_T, in_channels,
@@ -19,5 +18,52 @@ int DSCNN_LR(float* output_signal, float* input_signal, unsigned in_T, unsigned
1918
in_T, in_channels, cnn_padding, cnn_kernel_size,
2019
cnn_params, cnn_activations);
2120
free(norm_out);
21+
22+
return 0;
23+
}
24+
25+
int DSCNN_LR_Point_Depth(float* output_signal, float* input_signal, unsigned in_T, unsigned in_channels, float* mean, float* var,
26+
unsigned affine, float* gamma, float* beta, unsigned in_place, unsigned depth_cnn_hidden, int depth_cnn_padding,
27+
unsigned depth_cnn_kernel_size, const void* depth_cnn_params, int depth_cnn_activations, unsigned point_cnn_hidden,
28+
int point_cnn_padding, unsigned point_cnn_kernel_size, const void* point_cnn_params, int point_cnn_activations,
29+
int pool_padding, unsigned pool_kernel_size, int pool_activation){
30+
31+
// Activation
32+
unsigned out_T;
33+
float* act_out= (float*)malloc(in_T * (in_channels>>1) * sizeof(float));
34+
TanhGate(act_out, input_signal, in_T, in_channels);
35+
36+
// Norm
37+
in_channels >>= 1;
38+
// float* norm_out = (float*)malloc(in_T*in_channels*sizeof(float));
39+
BatchNorm1d(0, act_out, in_T, in_channels,
40+
mean, var, affine, gamma, beta, in_place);
41+
// free(act_out);
42+
43+
// Depth CNN
44+
out_T = in_T - depth_cnn_kernel_size + 2*depth_cnn_padding + 1;
45+
float* depth_out = (float*)malloc(out_T * depth_cnn_hidden * sizeof(float));
46+
Conv1D_Depth(depth_out, out_T, act_out,
47+
in_T, in_channels, depth_cnn_padding, depth_cnn_kernel_size,
48+
depth_cnn_params, depth_cnn_activations);
49+
// free(norm_out);
50+
free(act_out);
51+
52+
// Point CNN
53+
in_T = out_T;
54+
out_T = in_T - point_cnn_kernel_size + 2*point_cnn_padding + 1;
55+
float* point_out = (float*)malloc(out_T * point_cnn_hidden * sizeof(float));
56+
Conv1D_LR(point_out, out_T, point_cnn_hidden, depth_out,
57+
in_T, depth_cnn_hidden, point_cnn_padding, point_cnn_kernel_size,
58+
point_cnn_params, point_cnn_activations);
59+
free(depth_out);
60+
61+
// Pool
62+
in_T = out_T;
63+
out_T = in_T - pool_kernel_size + 2*pool_padding + 1;
64+
AvgPool1D(output_signal, out_T, point_out, in_T, point_cnn_hidden,
65+
pool_padding, pool_kernel_size, pool_activation);
66+
free(point_out);
67+
2268
return 0;
2369
}

c_reference/tests/Makefile

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ INCLUDE_DIR=../include
77
SRC_DIR=../src
88
IFLAGS = -I $(INCLUDE_DIR)
99

10-
all: test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv
10+
all: test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv
1111

1212
DSCNN_DIR=dscnn
1313
test_dscnn_lr: $(DSCNN_DIR)/test_dscnn_lr.c $(SRC_DIR)/conv_utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o
1414
$(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm
15+
test_dscnn_lr_depth_point : $(DSCNN_DIR)/test_dscnn_lr_depth_point.c $(SRC_DIR)/conv_utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o
16+
$(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm
1517

1618
CONV1D_DIR=conv1d
1719
test_conv1d: $(CONV1D_DIR)/conv1d_regular/test_conv1d.c $(SRC_DIR)/conv_utils.o $(SRC_DIR)/conv1d.o
@@ -48,7 +50,7 @@ test_quantized_mbconv: $(MBCONV_DIR)/test_quantized_mbconv.c $(SRC_DIR)/quantize
4850
.PHONY: clean cleanest
4951

5052
clean:
51-
rm -f *.o *.gch test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv
53+
rm -f *.o *.gch test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv
5254

5355
cleanest: clean
5456
rm *~

c_reference/tests/dscnn/dscnn_lr.h renamed to c_reference/tests/dscnn/dscnn_param_lr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
static float BNORM_CNN1_MEAN[I_F] = {10.719894, 11.277293, 11.355163, 11.932443, 12.354103, 12.917164, 13.324238, 13.575673, 13.637483, 13.515657, 13.448936, 13.534993, 13.772815, 14.070428, 14.2041855, 14.314233, 14.354123, 14.344929, 14.278268, 14.370038, 14.091814, 13.935598, 14.095416, 13.853564, 13.965211, 13.873084, 13.962391, 13.873242, 13.977497, 13.9614935, 14.016155, 14.080242, 14.141294, 14.232999, 14.415869, 14.547782, 14.620865, 14.728154, 14.8032675, 14.938353, 14.91555, 14.998761, 14.988357, 15.069421, 15.058382, 15.12057, 15.198176, 15.287275, 15.387576, 15.469517, 15.51293, 15.572873, 15.56175, 15.640288, 15.65414, 15.677645, 15.692691, 15.711035, 15.720091, 15.719153, 15.7531595, 15.746176, 15.745342, 15.679322, 15.626492, 15.573947, 15.527088, 15.491109, 15.432482, 15.394985, 15.371318, 15.341917, 15.331801, 15.335991, 15.337822, 15.325554, 15.294515, 15.265457, 15.205703, 15.127973};
1111

12+
// static float BNORM_CNN1_STDDEV[I_F] = {4.547185, 4.518235, 4.619436, 4.8532863, 4.970063, 5.115547, 5.2474685, 5.3445864, 5.3823757, 5.372092, 5.325787, 5.301521, 5.3339496, 5.3862495, 5.414599, 5.411404, 5.3931794, 5.364316, 5.336561, 5.3083525, 5.2459865, 5.1986265, 5.167714, 5.105501, 5.069721, 5.0293355, 4.999428, 4.9724503, 4.9639397, 4.9570994, 4.9546976, 4.9532127, 4.949258, 4.929101, 4.836479, 4.8515296, 4.964716, 4.978233, 4.9837084, 4.9872823, 4.9787045, 4.9738936, 4.9639654, 4.954813, 4.9563284, 4.966465, 4.981837, 5.0024242, 5.02435, 5.0360146, 5.0363855, 5.027525, 5.031958, 5.0330267, 5.039688, 5.0406976, 5.0352154, 5.0343814, 5.037096, 5.0439806, 5.0501595, 5.0427365, 5.0152955, 5.0615163, 5.068582, 5.0682364, 5.067946, 5.062103, 5.06264, 5.0651436, 5.060831, 5.063966, 5.061589, 5.0575705, 5.05634, 5.054741, 5.0505805, 5.0439153, 5.065279, 5.0873175};
13+
1214
static float BNORM_CNN1_VAR[I_F] = {20.676882, 20.414438, 21.339178, 23.554377, 24.701517, 26.168814, 27.535915, 28.564592, 28.969957, 28.859362, 28.364, 28.106115, 28.451006, 29.011675, 29.317871, 29.283287, 29.086376, 28.775879, 28.478878, 28.178596, 27.520367, 27.02571, 26.705261, 26.066133, 25.702063, 25.294205, 24.99427, 24.72525, 24.640686, 24.572824, 24.549019, 24.534306, 24.495142, 24.296028, 23.391523, 23.53733, 24.648394, 24.782791, 24.837341, 24.872974, 24.787487, 24.739609, 24.640944, 24.550161, 24.56518, 24.665766, 24.818687, 25.024239, 25.244087, 25.361431, 25.36517, 25.276, 25.320591, 25.33135, 25.398449, 25.40862, 25.353382, 25.344984, 25.372326, 25.441729, 25.504103, 25.42918, 25.153177, 25.618937, 25.690516, 25.687012, 25.684065, 25.624876, 25.630314, 25.655672, 25.612, 25.64374, 25.61967, 25.579008, 25.566566, 25.550398, 25.508354, 25.44107, 25.65704, 25.880789};
1315

1416
static float CNN1_BIAS[O_F] = {1.300191, -0.779042, 0.12877105, -2.037354, -0.9435298, -2.5254133, -1.5369109, 1.0702158, 3.0273507, -0.7386386, 1.7626207, -0.58604527, 2.106561, -1.5296501, -1.2816843, -0.6733964, -0.7367651, 1.7345444, -0.05970213, 1.3843738, 1.6244868, 1.8320789, 0.5069751, 0.62708217, 2.003625, 0.6377755, -1.9624366, 0.63646936, 2.0180166, -0.25709784, 2.255273, 0.38705197, 2.169011, 1.1031395, -1.5280739, -0.76651794, -1.082367, 1.9920447, -2.0233238, 1.6334648, 0.8347151, -1.0188212, -1.0827045, -1.7733161, -0.45355535, 1.0320476, -2.4755456, 0.56821054, 0.75964713, 0.06325834, -1.3351326, 0.947213, -2.7096841, -0.2805782, 0.5053991, 0.33857206, -1.5643214, -1.9894447, -1.7913662, 1.914957, 2.0200634, 1.3860161, -1.8900268, -1.3497273, 1.0761157, 1.7763672, -0.65279585, 1.9923992, -0.5291019, 1.5516317, 1.676236, -2.257369, 2.3529193, 1.0615077, 0.51217747, -2.0283973, -1.3906946, -0.005759739, -1.014273, -1.7721412, -1.7974949, -1.7176688, -2.0392191, 0.87384427, 0.7713191, -0.21536666, 1.7571598, -1.9713863, -1.7524513, 2.239532, -2.0877903, 0.67606, -1.2471493, -1.3209369, -1.5323899, -0.86020976, -0.19431618, -0.29289863, -1.1362463, -0.27989164, 0.9613982, -0.9303186, 0.80145055, -1.5114999, -0.53770745, -1.0007375, 1.2348008, -0.7031456, 1.9639963, -1.8767631, 0.54409057, 1.5087916, -1.4712379, -0.93829006, -1.3319114, -2.3127968, -1.7907258, -0.407395, 0.6168143, 0.82548475, 0.8482644, -0.8501224, 1.6107091, 2.134612, -1.0925547, 1.7021487, -2.1517541, 1.1478122, -1.8185, -0.60170245, -1.5785563, -1.4434805, -0.84806854, 2.124417, -1.7992799, 1.0451778, 2.0745783, -0.49988687, 0.44799557, -2.2847998, 1.6671776, 0.2003777, 0.40834093, 1.8844966, -1.5649115, -2.078135, 1.4730648, 2.2270346, 1.5995127, 0.3647449, 1.0695671, -1.8636547, -1.0998644, -2.5635734, 0.017660717, -1.9947071, 1.5408367, 0.90499216, 0.91611373, 1.8415768, -0.934941, 0.99624336, -1.7224451, -2.117806, -1.8641261, -0.27287853, -1.3290448, 0.68198127, 1.0655302, 1.3089589, 1.5892235, 1.9483236, 2.671394, 2.0462663, 1.0995138, -0.8231849, 0.43930018, 1.6719505, -1.0527816, -0.652467, 1.2310717, 1.6923088, -0.35652065, -1.027436, -0.22795796, 0.109296955, -0.6748324, 0.5586138, -1.180884, -2.18752, 1.1215136, 1.6672724, -1.4353535, 2.0336459, -2.7333302, -1.4794307, -1.70236, -0.1282032, 1.5664086, -2.98293, -2.2209318, 0.8168484, -0.78808755, -1.5923033, 0.96671927, -1.9981681, -1.8891708, -0.71883076, 0.9872053, 1.8431993, 1.4599383, 1.461618, -2.6698012, 1.7014012, -0.90234613, -3.1844337, -2.0179415, 1.091565, -1.1652547, -1.0443002, -0.048881833, -1.5857307, 1.8649789, 1.9054865, 0.9829317, -0.8272472, 2.2824135, 0.95920223, -1.2486432, -0.44787234, -1.4609123, 1.9108136, 1.3920923, 1.0286125, 1.6861342, 1.452765, 1.1245193, -1.054227, 0.6595548, 1.0182625, -1.7440542, 0.9177331, -0.3515849, -1.4495276, 2.107017, 1.6213392, -1.5320052, -0.80822104, -2.0555818, -1.9608325, -1.7011721, -1.877503, -0.7057842, -2.0763748, -1.5858908, 1.2160226, -2.4326186, 1.6956238, -1.5471596, -1.7387867, 2.1300697, 0.18794368, -0.4583019, -3.1489387, 1.6009101, 2.1964102, -0.6978634, 1.2681695, 0.4626042, -1.461521, -0.14139582, 1.1862375, 1.5006682, -1.5338326, 1.7466931, -1.3882873, 0.36158136, 0.11920393, -1.525045, 0.9772708, 2.0064225, 0.35159564, -0.9488687, 1.5600407, 1.3372803, -2.4081395, 1.7896626, 0.001687233, 1.5024955, 1.0004863, 1.2808807, -0.23564178, 1.6312653, -1.4802891, 2.3526027, -2.7684214, 1.371979, 0.967034, 1.6903735, -1.7038316, 1.52379, -1.1193563, -2.096888, 2.0172017, 1.8445942, 0.56659895, 1.669906, -1.4930186, 2.149227, 1.5583338, -2.212712, -1.1139998, -1.9200424, -1.4756505, 0.29662177, 1.1342987, -1.1442064, -1.3027734, 1.4665192, 1.3889705, 0.76326215, -1.6858389, -2.646514, 1.2635381, 1.2053379, -0.15261778, 1.3416804, -0.57838917, -0.14112566, -2.0271254, 0.36985314, -1.3416636, 1.1499223, -1.9333936, 0.9676799, 1.846942, -1.823447, -0.44626495, -0.8492061, -0.94446206, 2.1404736, -0.58360445, 2.5060663, -1.4616301, -2.201911, -2.0038357, 1.6050764, -0.68346953, -0.048071913, 1.2584175, 1.2293136, 0.7809569, 1.0426509, -2.1945558, -0.80175585, -2.3009183, -0.6103452, -0.2072521, -0.21979602, 1.4457221, 0.12900288, -0.6473033, 1.5531462, -0.5543701, -0.8842336, 0.6678211, 0.33419546, 1.430383, 1.8887885, -3.1204836, 0.9718442, 1.9787605, -2.7728877, -1.8637267, -1.8636626, -0.34137356, -2.4106872, -1.9265975, -1.6192602, 1.2435093, -1.6963999, 1.5720334, -2.2001069, -0.8705381, 0.937398, -2.1680825, 0.34210345, 0.5386563, 1.7298652, 1.496569, -0.34388074, -2.686154, -2.1040716, 1.0595173, -1.2657964, 0.68641293, 0.40762904, 0.88818675, 1.066454, -0.6460059};

0 commit comments

Comments
 (0)