Skip to content

Commit 3d80466

Browse files
author
bssrdf
committed
Tests now pass; for some reason, ggml_conv_2d didn't output correct results
1 parent 02a3cb1 commit 3d80466

File tree

6 files changed

+60

-174

lines changed

src/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,9 @@ if (GGML_CUDA)
285285
# 61 == integer CUDA intrinsics
286286
# 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
287287
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
288-
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
288+
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;86")
289289
else()
290-
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
290+
set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75;86")
291291
#set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
292292
endif()
293293
endif()

src/ggml-cuda/conv-winograd.cu

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,8 @@ __device__ float f_row1(float *Gw, int j){
386386

387387
__global__ void FX(const float *pInputs, float *pOutputs, int filt_k,
388388
int filt_c, int filt_h, int filt_w){
389+
390+
// assumes CHWK layout
389391
int Inx = threadIdx.x, Iny = threadIdx.y;
390392
int TileX = blockIdx.x, TileY = blockIdx.y;
391393

@@ -725,31 +727,35 @@ static void conv_winograd_stage0_f32_f32_cuda(
725727
cudaStream_t stream) {
726728

727729

728-
int64_t filt_k = src0_ne3;
729-
int64_t filt_c = src0_ne2;
730+
int64_t filt_k = src0_ne0;
731+
int64_t filt_c = src0_ne3;
730732

731-
FX<<<dim3(filt_k/BK, filt_c/BC), dim3(32, BC)>>>(src0, dst, filt_k, filt_c, src0_ne1, src0_ne0);
733+
FX<<<dim3(filt_k/BK, filt_c/BC), dim3(32, BC)>>>(src0, dst, filt_k, filt_c, src0_ne2, src0_ne1);
732734

733735
}
734736

735-
static void conv_winograd_stage1_f16_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y,
737+
static void conv_winograd_stage1_f32_f32_cuda(int tiles_dim_w, int tiles_dim_h, int X, int Y,
736738
int tile_size, int tile_2d_s,
737739
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
738740
const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
739741
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
740742
const float * src0, const float * src1, float * dst,
741743
cudaStream_t stream) {
742744

743-
int64_t filt_k = src0_ne3;
745+
int64_t filt_k = src0_ne0;
744746
int64_t in_c = src1_ne2;
745747
int64_t in_h = src1_ne1;
746748
int64_t in_w = src1_ne0;
747-
int64_t filt_c = src1_ne0;
749+
int64_t filt_c = src0_ne3;
748750
int64_t out_c = filt_k;
749751
int64_t out_h = in_h;
750752
int64_t out_w = in_w;
751753
int smem_size = (16*BN*BC + 16*BC*BK)*4;
752754

755+
printf("A %d, %d\n", filt_k, filt_c);
756+
printf("B %d, %d, %d \n", in_c, in_h, in_w);
757+
printf("C %d, %d, %d \n", out_c, out_h, out_w);
758+
753759
Winograd_kernel<<<dim3((tiles_dim_w+X-1)/X, (tiles_dim_h+Y-1)/Y, filt_k/BK), dim3(BN, 8), smem_size>>>(src1, src0, dst,
754760
tiles_dim_w, tiles_dim_h, in_c, in_h, in_w, tile_size, X, Y, filt_k, filt_c, out_c, tile_2d_s, out_h, out_w);
755761
}
@@ -816,8 +822,8 @@ void ggml_cuda_op_winograd_stage1(ggml_backend_cuda_context & ctx, ggml_tensor *
816822
cudaMemcpyToSymbol(access_f_s, aux, 64*sizeof(int));
817823
cudaMemcpyToSymbol(access_s, aux2, 64*sizeof(int));
818824
cudaMemcpyToSymbol(tileid, tid, 64*sizeof(int));
819-
820-
conv_winograd_stage1_f16_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8,
825+
printf(" %d, %d, %d \n", tiles_dim_w, tiles_dim_h, tile_size);
826+
conv_winograd_stage1_f32_f32_cuda(tiles_dim_w, tiles_dim_h, 4, 8,
821827
tile_size, tile_2d_s,
822828
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
823829
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],

src/ggml-cuda/conv-winograd.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "common.cuh"
22

3-
// #define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256
3+
44
#define BC 8
55
#define BN 32
66
#define BK 64

src/ggml.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7179,12 +7179,12 @@ struct ggml_tensor * ggml_winograd_stage0(
71797179
struct ggml_context * ctx,
71807180
struct ggml_tensor * a) {
71817181
bool is_node = false;
7182-
GGML_ASSERT(a->ne[0] == 3 && a->ne[1] == 3); // kernel should be 3x3
7182+
71837183
if (a->grad) {
71847184
is_node = true;
71857185
}
71867186

7187-
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 4, 4, a->ne[2], a->ne[3]);
7187+
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], 4, 4, a->ne[3]);
71887188

71897189
result->op = GGML_OP_WINOGRAD_STAGE0;
71907190
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7208,7 +7208,7 @@ struct ggml_tensor * ggml_winograd_stage1(
72087208

72097209
int OW = b->ne[0];
72107210
int OH = b->ne[1];
7211-
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, OW, OH, a->ne[3] /* OC */, 1);
7211+
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, OW, OH, a->ne[0] /* OC */, 1);
72127212

72137213
result->op = GGML_OP_WINOGRAD_STAGE1;
72147214
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7222,14 +7222,14 @@ struct ggml_tensor * ggml_conv_2d_3x3(
72227222
struct ggml_context * ctx,
72237223
struct ggml_tensor * a,
72247224
struct ggml_tensor * b){
7225-
7225+
GGML_ASSERT(a->ne[0] == 3 && a->ne[1] == 3); // kernel should be 3x3
72267226
GGML_ASSERT(b->ne[3] == 1); // only works for 1 input image
72277227
GGML_ASSERT(b->ne[2] == a->ne[2]); // number of channels must match
72287228
if(a->ne[3] % 64 != 0 || a->ne[2] % 8 != 0) // only works for the number of filters is a multiple of 64
72297229
return ggml_conv_2d(ctx, a, b, 1, 1, 1, 1, 1, 1); // and the number of channels is a multiple of 8
72307230

7231-
7232-
struct ggml_tensor* W = ggml_winograd_stage0(ctx, a);
7231+
struct ggml_tensor* ra = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 3, 0)); // [N, OC, OH, OW]
7232+
struct ggml_tensor* W = ggml_winograd_stage0(ctx, ra);
72337233
struct ggml_tensor * result = ggml_winograd_stage1(ctx, W, b);
72347234

72357235
return result;

tests/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,15 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml)
408408
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
409409
set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
410410

411+
#
412+
# test-conv2d-wino
413+
414+
set(TEST_TARGET test-conv2d-winograd)
415+
add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
416+
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
417+
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
418+
set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
419+
411420

412421
#
413422
# test-mul-mat

tests/test-conv2d-winograd.cpp

Lines changed: 28 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ struct test_model {
3636

3737
void load_model(test_model & model, bool use_gpu = false) {
3838
// create data
39-
int KW = 3, KH = 3, IC = 10, OC = 10;
40-
int IW = 8, IH = 6, N = 1;
39+
int KW = 3, KH = 3, IC = 32, OC = 64;
40+
int IW = 28, IH = 40, N = 1;
4141

4242
// Initialize adata
4343
std::vector<float> adata(KW * KH * IC * OC);
@@ -157,16 +157,21 @@ struct ggml_cgraph * build_graph(const test_model& model) {
157157
int d0 = 1;
158158
int d1 = 1;
159159

160-
// split conv2d in fundamental methods for test unit
161-
struct ggml_tensor* im2col_0 = ggml_im2col(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16);
162-
ggml_set_name(im2col_0, "im2col_res");
163-
ggml_build_forward_expand(gf, im2col_0);
160+
164161

165162
// recalculate for avoid fragmentation
166163
struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
167164
ggml_set_name(conv2d_res, "conv2d_res");
168165
ggml_build_forward_expand(gf, conv2d_res);
166+
int64_t *ne = conv2d_res->ne;
167+
printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
169168

169+
170+
struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b);
171+
ggml_set_name(wino_res, "wino_res");
172+
ggml_build_forward_expand(gf, wino_res);
173+
ne = wino_res->ne;
174+
printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
170175
ggml_free(ctx0);
171176
return gf;
172177
}
@@ -218,173 +223,39 @@ int main(void)
218223

219224
struct ggml_cgraph * gf_res = compute_graph(model, allocr);
220225

221-
struct ggml_tensor * im2col_res = NULL;
226+
struct ggml_tensor * wino_res = NULL;
222227
struct ggml_tensor * conv2d_res = NULL;
223228

224229
for(int i = 0; i < ggml_graph_n_nodes(gf_res); ++i) {
225-
if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "im2col_res") == 0) {
226-
im2col_res = ggml_graph_node(gf_res, i);
230+
if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "wino_res") == 0) {
231+
wino_res = ggml_graph_node(gf_res, i);
227232
} else if(strcmp(ggml_get_name(ggml_graph_node(gf_res, i)), "conv2d_res") == 0) {
228233
conv2d_res = ggml_graph_node(gf_res, i);
229234
}
230235
}
231236

232-
std::vector<uint16_t> im2col_data(ggml_nelements(im2col_res));
237+
std::vector<float> wino_data(ggml_nelements(wino_res));
233238
std::vector<float> conv2d_data(ggml_nelements(conv2d_res));
234239

235-
ggml_backend_tensor_get(im2col_res, im2col_data.data(), 0, ggml_nbytes(im2col_res));
240+
ggml_backend_tensor_get(wino_res, wino_data.data(), 0, ggml_nbytes(wino_res));
236241
ggml_backend_tensor_get(conv2d_res, conv2d_data.data(), 0, ggml_nbytes(conv2d_res));
237242

238-
const int n_conv2d_test = 480;
239-
const int n_im2col_test = 4320;
240-
241-
float expected_conv2d [n_conv2d_test] = {
242-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
243-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
244-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
245-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
246-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
247-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
248-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
249-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
250-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
251-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
252-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
253-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
254-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
255-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
256-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
257-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
258-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
259-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
260-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
261-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
262-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
263-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
264-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
265-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
266-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
267-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
268-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
269-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
270-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
271-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
272-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
273-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
274-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
275-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
276-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
277-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
278-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
279-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
280-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
281-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
282-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
283-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
284-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
285-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
286-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
287-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
288-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
289-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
290-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
291-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
292-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
293-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
294-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
295-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
296-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f,
297-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
298-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
299-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
300-
225.00f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 337.50f, 225.00f,
301-
150.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 225.00f, 150.00f };
302-
303-
uint16_t expected_im2col[n_conv2d_test] = {
304-
0, 0, 0, 0, 15872, 15872, 0, 15872,
305-
15872, 0, 0, 0, 0, 15872, 15872, 0,
306-
15872, 15872, 0, 0, 0, 0, 15872, 15872,
307-
0, 15872, 15872, 0, 0, 0, 0, 15872,
308-
15872, 0, 15872, 15872, 0, 0, 0, 0,
309-
15872, 15872, 0, 15872, 15872, 0, 0, 0,
310-
0, 15872, 15872, 0, 15872, 15872, 0, 0,
311-
0, 0, 15872, 15872, 0, 15872, 15872, 0,
312-
0, 0, 0, 15872, 15872, 0, 15872, 15872,
313-
0, 0, 0, 0, 15872, 15872, 0, 15872,
314-
15872, 0, 0, 0, 0, 15872, 15872, 0,
315-
15872, 15872, 0, 0, 0, 15872, 15872, 15872,
316-
15872, 15872, 15872, 0, 0, 0, 15872, 15872,
317-
15872, 15872, 15872, 15872, 0, 0, 0, 15872,
318-
15872, 15872, 15872, 15872, 15872, 0, 0, 0,
319-
15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
320-
0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
321-
0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
322-
0, 0, 0, 15872, 15872, 15872, 15872, 15872,
323-
15872, 0, 0, 0, 15872, 15872, 15872, 15872,
324-
15872, 15872, 0, 0, 0, 15872, 15872, 15872,
325-
15872, 15872, 15872, 0, 0, 0, 15872, 15872,
326-
15872, 15872, 15872, 15872, 0, 0, 0, 15872,
327-
15872, 15872, 15872, 15872, 15872, 0, 0, 0,
328-
15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
329-
0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
330-
0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
331-
0, 0, 0, 15872, 15872, 15872, 15872, 15872,
332-
15872, 0, 0, 0, 15872, 15872, 15872, 15872,
333-
15872, 15872, 0, 0, 0, 15872, 15872, 15872,
334-
15872, 15872, 15872, 0, 0, 0, 15872, 15872,
335-
15872, 15872, 15872, 15872, 0, 0, 0, 15872,
336-
15872, 15872, 15872, 15872, 15872, 0, 0, 0,
337-
15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
338-
0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
339-
0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
340-
0, 0, 0, 15872, 15872, 15872, 15872, 15872,
341-
15872, 0, 0, 0, 15872, 15872, 15872, 15872,
342-
15872, 15872, 0, 0, 0, 15872, 15872, 15872,
343-
15872, 15872, 15872, 0, 0, 0, 15872, 15872,
344-
15872, 15872, 15872, 15872, 0, 0, 0, 15872,
345-
15872, 15872, 15872, 15872, 15872, 0, 0, 0,
346-
15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
347-
0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
348-
0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
349-
0, 0, 0, 15872, 15872, 15872, 15872, 15872,
350-
15872, 0, 0, 0, 15872, 15872, 15872, 15872,
351-
15872, 15872, 0, 0, 0, 15872, 15872, 15872,
352-
15872, 15872, 15872, 0, 0, 0, 15872, 15872,
353-
15872, 15872, 15872, 15872, 0, 0, 0, 15872,
354-
15872, 15872, 15872, 15872, 15872, 0, 0, 0,
355-
15872, 15872, 15872, 15872, 15872, 15872, 0, 0,
356-
0, 15872, 15872, 15872, 15872, 15872, 15872, 0,
357-
0, 0, 15872, 15872, 15872, 15872, 15872, 15872,
358-
0, 0, 0, 15872, 15872, 15872, 15872, 15872,
359-
15872, 0, 0, 0, 15872, 15872, 15872, 15872,
360-
15872, 15872, 0, 0, 0, 15872, 15872, 15872,
361-
15872, 15872, 15872, 0, 0, 0, 15872, 15872,
362-
15872, 15872, 15872, 15872, 0, 0, 0, 15872,
363-
15872, 15872, 15872, 15872, 15872, 0, 0, 0
364-
};
365-
366-
printf("\nPerforming test:\n");
243+
244+
printf("\nPerforming test:\n");
367245

368246
bool passed = true;
369-
for(int i = 0; i < n_conv2d_test; i++) {
370-
if(
371-
im2col_data[i] != expected_im2col[i]) {
372-
passed = false;
373-
break;
374-
}
375-
}
376-
377-
printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
378-
379-
passed = true;
380-
for(int i = 0; i < n_conv2d_test; i++) {
381-
if(conv2d_data[i] != expected_conv2d[i]) {
382-
passed = false;
383-
break;
384-
}
247+
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
248+
for(int i = 0; i < 3*28; i++) {
249+
float diff = fabs(conv2d_data[i] - wino_data[i]);
250+
// if(diff > 1.e-4) {
251+
printf("(%f, %f, %f, %d) \n",
252+
conv2d_data[i],
253+
wino_data[i], diff, i);
254+
// break;
255+
// }
385256
}
386257

387-
printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
258+
388259

389260
ggml_free(model.ctx);
390261

0 commit comments

Comments (0)