Skip to content

Commit 528ce8d

Browse files
committed
fix a bug
1 parent 7e17879 commit 528ce8d

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/cuda/conv.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include <conv.cuh>
1+
#include <conv.cuh>
22

33
#include <thrust/copy.h>
44
#include <thrust/device_vector.h>
@@ -238,7 +238,7 @@ void operator_d_conv(
238238
// dL/d_col = F^T * dL/dY
239239
std::vector<int> dl_dcol_shape{batch_size, channel_in * kernel_h * kernel_w,
240240
height_col * width_col};
241-
INIT_TEMP(temp, "dl_dcol", dl_df_shape);
241+
INIT_TEMP(temp, "dl_dcol", dl_dcol_shape);
242242
operator_matmul(temp["filters_t"].get(), outputs_grad, temp["dl_dcol"].get(),
243243
1); // broadcast param 1
244244

src/cuda/layer.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <storage.cuh>
4+
#include <utils.cuh>
45

56
#include <memory>
67
#include <vector>

0 commit comments

Comments
 (0)