Skip to content

Commit 579b6be

Browse files
Fixed matmul
1 parent 824955e commit 579b6be

File tree

4 files changed

+3
-40
lines changed

4 files changed

+3
-40
lines changed

c_reference/include/conv_utils.h

Lines changed: 0 additions & 11 deletions
This file was deleted.

c_reference/src/conv1d.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ int conv1d_lr(float *output_signal, unsigned out_time, unsigned out_channels, co
1313
const ConvLayers_LR_Params* tparams= (ConvLayers_LR_Params*) params;
1414

1515
float* tempW = (float*)malloc(out_channels * in_channels * kernel_size * sizeof(float));
16-
matmul(tempW, tparams->W1, tparams->W2, tparams->rank, out_channels, in_channels * kernel_size);
16+
matmul(tparams->W1, tparams->W2, out_channels, tparams->rank, in_channels * kernel_size, 0, 1.0, tempW);
1717
// Perform the Convolution
1818
for (int t = 0; t < out_time; t++) {
1919
for (int co = 0; co < out_channels; co++) {
@@ -48,7 +48,7 @@ int conv1d_depth_lr(float *output_signal, unsigned out_time, const float *input_
4848
const ConvLayers_LR_Params* tparams= (ConvLayers_LR_Params*) params;
4949

5050
float* tempW = (float*)malloc(in_channels * kernel_size * sizeof(float));
51-
matmul(tempW, tparams->W1, tparams->W2, tparams->rank, in_channels, kernel_size);
51+
matmul(tparams->W1, tparams->W2, in_channels, tparams->rank, kernel_size, 0, 1.0, tempW);
5252
// Perform the Convolution
5353
for (int t = 0; t < out_time; t++) {
5454
for (int ci = 0; ci < in_channels; ci++) {

c_reference/src/conv_utils.c

Lines changed: 0 additions & 26 deletions
This file was deleted.

c_reference/src/utils.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void matmul(const float* const matA, const float* const matB,
8080
float sum = 0;
8181
for(int k = 0; k < ncommon; k++)
8282
sum += (matA[row * ncommon + k] * matB[k * ncols + col]);
83-
out[row * ncols + col] = alpha * out[row * ncols + col] + beta * sum;
83+
ret[row * ncols + col] = alpha * ret[row * ncols + col] + beta * sum;
8484
}
8585
}
8686
}

0 commit comments

Comments
 (0)