Skip to content

Commit a6d8836

Browse files
add epsilon to variance to avoid division by zero
1 parent c280665 commit a6d8836

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

c_reference/src/conv1d.c

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,18 +212,19 @@ int AvgPool1D(float *output_signal, unsigned out_T, const float *input_signal, u
212212

213213
int BatchNorm1d(float* output_signal, float* input_signal, unsigned in_T, unsigned in_channels,
214214
float* mean, float* var, unsigned affine, float* gamma , float * beta, unsigned in_place){
215+
float eps = 0.00001;
215216
if(affine){
216217
if(in_place){
217218
for(int t = 0; t < in_T ; t++){
218219
for(int d = 0 ; d < in_channels ; d++){
219-
input_signal[t * in_channels + d] = gamma[d]*((input_signal[t * in_channels + d] - mean[d])/sqrt(var[d])) + beta[d];
220+
input_signal[t * in_channels + d] = gamma[d]*((input_signal[t * in_channels + d] - mean[d])/sqrt(var[d] + eps)) + beta[d];
220221
}
221222
}
222223
}
223224
else{
224225
for(int t = 0; t < in_T ; t++){
225226
for(int d = 0 ; d < in_channels ; d++){
226-
output_signal[t * in_channels + d] = gamma[d]*((input_signal[t * in_channels + d] - mean[d])/sqrt(var[d])) + beta[d];
227+
output_signal[t * in_channels + d] = gamma[d]*((input_signal[t * in_channels + d] - mean[d])/sqrt(var[d] + eps)) + beta[d];
227228
}
228229
}
229230
}
@@ -232,14 +233,14 @@ int BatchNorm1d(float* output_signal, float* input_signal, unsigned in_T, unsign
232233
if(in_place){
233234
for(int t = 0; t < in_T ; t++){
234235
for(int d = 0 ; d < in_channels ; d++){
235-
input_signal[t * in_channels + d] = ((input_signal[t * in_channels + d] - mean[d])/sqrt(var[d]));
236+
input_signal[t * in_channels + d] = ((input_signal[t * in_channels + d] - mean[d])/sqrt(var[d] + eps));
236237
}
237238
}
238239
}
239240
else{
240241
for(int t = 0; t < in_T ; t++){
241242
for(int d = 0 ; d < in_channels ; d++){
242-
output_signal[t * in_channels + d] = ((input_signal[t * in_channels + d] - mean[d])/sqrt(var[d]));
243+
output_signal[t * in_channels + d] = ((input_signal[t * in_channels + d] - mean[d])/sqrt(var[d] + eps));
243244
}
244245
}
245246
}

0 commit comments

Comments
 (0)