10
10
11
11
int phon_pred_lr_cnn (float * output_signal , float * input_signal ,
12
12
unsigned in_time , unsigned in_channels ,
13
- float * mean , float * var , unsigned affine , float * gamma , float * beta , unsigned in_place ,
13
+ const float * const mean , const float * const var ,
14
+ unsigned affine , float * gamma , float * beta , unsigned in_place ,
14
15
unsigned cnn_hidden , unsigned cnn_padding , unsigned cnn_kernel_size ,
15
- const void * cnn_params , unsigned cnn_stride , int cnn_activation ) {
16
+ const void * cnn_params , unsigned cnn_stride , unsigned cnn_activation ) {
16
17
17
18
unsigned out_time = in_time - cnn_kernel_size + 2 * cnn_padding + 1 ;
18
19
if (in_place ) {
@@ -44,29 +45,31 @@ int phon_pred_lr_cnn(float* output_signal, float* input_signal,
44
45
45
46
int phon_pred_depth_point_lr_cnn (float * output_signal , float * input_signal ,
46
47
unsigned in_time , unsigned in_channels ,
47
- float * mean , float * var , unsigned affine , float * gamma , float * beta , unsigned in_place ,
48
- unsigned depth_cnn_hidden , unsigned depth_cnn_padding , unsigned depth_cnn_kernel_size ,
49
- const void * depth_cnn_params , unsigned depth_cnn_stride , int depth_cnn_activation ,
48
+ const float * const mean , const float * const var ,
49
+ unsigned affine , const float * const gamma , const float * const beta , unsigned in_place ,
50
+ unsigned depth_cnn_padding , unsigned depth_cnn_kernel_size ,
51
+ const void * depth_cnn_params , unsigned depth_cnn_stride , unsigned depth_cnn_activation ,
50
52
unsigned point_cnn_hidden , unsigned point_cnn_padding , unsigned point_cnn_kernel_size ,
51
- const void * point_cnn_params , unsigned point_cnn_stride , int point_cnn_activation ,
52
- unsigned pool_padding , unsigned pool_kernel_size , unsigned pool_stride , int pool_activation ) {
53
+ const void * point_cnn_params , unsigned point_cnn_stride , unsigned point_cnn_activation ,
54
+ unsigned pool_padding , unsigned pool_kernel_size , unsigned pool_stride , unsigned pool_activation ) {
53
55
54
56
// Activation
55
- unsigned out_time ;
57
+
56
58
float * act_out = (float * )malloc (in_time * (in_channels >> 1 ) * sizeof (float ));
57
59
semi_sigmoid_tanh (act_out , input_signal , in_time , in_channels );
58
60
59
61
in_channels >>= 1 ;
60
62
float * depth_out ;
63
+ unsigned out_time = in_time - depth_cnn_kernel_size + 2 * depth_cnn_padding + 1 ;
61
64
if (in_place ) {
62
65
// Norm
63
66
batchnorm1d (0 , act_out ,
64
67
in_time , in_channels ,
65
- mean , var , affine , gamma , beta ,
68
+ mean , var ,
69
+ affine , gamma , beta ,
66
70
in_place , 0.00001 );
67
71
// Depth CNN
68
- out_time = in_time - depth_cnn_kernel_size + 2 * depth_cnn_padding + 1 ;
69
- depth_out = (float * )malloc (out_time * depth_cnn_hidden * sizeof (float ));
72
+ depth_out = (float * )malloc (out_time * in_channels * sizeof (float ));
70
73
conv1d_depth (depth_out , out_time , act_out ,
71
74
in_time , in_channels , depth_cnn_padding , depth_cnn_kernel_size ,
72
75
depth_cnn_params , depth_cnn_stride , depth_cnn_activation );
@@ -77,12 +80,12 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal,
77
80
float * norm_out = (float * )malloc (in_time * in_channels * sizeof (float ));
78
81
batchnorm1d (norm_out , act_out ,
79
82
in_time , in_channels ,
80
- mean , var , affine , gamma , beta ,
83
+ mean , var ,
84
+ affine , gamma , beta ,
81
85
in_place , 0.00001 );
82
86
free (act_out );
83
87
// Depth CNN
84
- out_time = in_time - depth_cnn_kernel_size + 2 * depth_cnn_padding + 1 ;
85
- depth_out = (float * )malloc (out_time * depth_cnn_hidden * sizeof (float ));
88
+ depth_out = (float * )malloc (out_time * in_channels * sizeof (float ));
86
89
conv1d_depth (depth_out , out_time , norm_out ,
87
90
in_time , in_channels , depth_cnn_padding , depth_cnn_kernel_size ,
88
91
depth_cnn_params , depth_cnn_stride , depth_cnn_activation );
@@ -94,7 +97,7 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal,
94
97
out_time = in_time - point_cnn_kernel_size + 2 * point_cnn_padding + 1 ;
95
98
float * point_out = (float * )malloc (out_time * point_cnn_hidden * sizeof (float ));
96
99
conv1d_lr (point_out , out_time , point_cnn_hidden , depth_out ,
97
- in_time , depth_cnn_hidden , point_cnn_padding , point_cnn_kernel_size ,
100
+ in_time , in_channels , point_cnn_padding , point_cnn_kernel_size ,
98
101
point_cnn_params , point_cnn_stride , point_cnn_activation );
99
102
free (depth_out );
100
103
0 commit comments