1+ #include < ggml.h>
2+ #include < ggml-cpu.h>
3+ #include < ggml-alloc.h>
4+ #include < ggml-backend.h>
5+ #include < ggml-cpp.h>
6+
7+ #include < cassert>
8+ #include < cmath>
9+ #include < cstdio>
10+ #include < array>
11+ #include < vector>
12+
13+ bool check_equal (const float * result, const float * expected, int64_t n) {
14+ for (int i = 0 ; i < n; i++) {
15+ if (std::abs (result[i] - expected[i]) > 1e-4 ) {
16+ printf (" result[%d] %f != %f expected[%d]\n " , i, result[i], expected[i], i);
17+ return false ;
18+ }
19+ }
20+ return true ;
21+ }
22+
23+ bool test_interpolate (char const * name,
24+ std::array<int64_t , 4 > src_ne, const float * src_data,
25+ std::array<int32_t , 4 > dst_ne, const float * expected,
26+ uint32_t mode) {
27+ ggml_time_init ();
28+
29+ ggml_init_params params {
30+ /* .mem_size =*/ 64 * ggml_tensor_overhead () + ggml_graph_overhead (),
31+ /* .mem_buffer =*/ NULL ,
32+ /* .no_alloc =*/ true
33+ };
34+
35+ ggml_context_ptr ctx_ptr{ggml_init (params)};
36+ ggml_context * ctx = ctx_ptr.get ();
37+ ggml_cgraph * gf = ggml_new_graph (ctx);
38+
39+ // Build graph
40+ ggml_tensor * src = ggml_new_tensor (ctx, GGML_TYPE_F32, 4 , src_ne.data ());
41+ ggml_tensor * res = ggml_interpolate (ctx, src, dst_ne[0 ], dst_ne[1 ], dst_ne[2 ], dst_ne[3 ], mode);
42+ ggml_build_forward_expand (gf, res);
43+
44+ // Create backend & allocate buffers
45+ ggml_backend_ptr backend_ptr{ggml_backend_cpu_init ()};
46+ ggml_backend_t backend = backend_ptr.get ();
47+ ggml_backend_cpu_set_n_threads (backend, 2 );
48+ ggml_backend_buffer_ptr buffer{ggml_backend_alloc_ctx_tensors (ctx, backend)};
49+
50+ // Execute and compare results
51+ ggml_backend_tensor_set (src, src_data, 0 , ggml_nbytes (src));
52+ ggml_backend_graph_compute (backend, gf);
53+
54+ std::vector<float > res_values (ggml_nelements (res));
55+ ggml_backend_tensor_get (res, res_values.data (), 0 , ggml_nbytes (res));
56+
57+ bool passed = check_equal (res_values.data (), expected, ggml_nelements (res));
58+
59+ printf (" ggml_interpolate(%s): %s\n " , name, passed ? " \033 [32mPASSED\033 [0m" : " \033 [31mFAILED\033 [0m" );
60+ return passed;
61+ }
62+
63+ const float input_upscale[] = {
64+ 0 .0f , 1 .0f ,
65+ 2 .0f , 4 .0f
66+ };
67+
68+ const float expected_upscale_x2_nearest[] = {
69+ 0 .0f , 0 .0f , 1 .0f , 1 .0f ,
70+ 0 .0f , 0 .0f , 1 .0f , 1 .0f ,
71+ 2 .0f , 2 .0f , 4 .0f , 4 .0f ,
72+ 2 .0f , 2 .0f , 4 .0f , 4 .0f
73+ };
74+
75+ const float expected_upscale_x2_bilinear[] = {
76+ 0 .0f , 0 .2500f , 0 .7500f , 1 .00f ,
77+ 0 .5f , 0 .8125f , 1 .4375f , 1 .75f ,
78+ 1 .5f , 1 .9375f , 2 .8125f , 3 .25f ,
79+ 2 .0f , 2 .5000f , 3 .5000f , 4 .00f
80+ };
81+
82+ const float expected_upscale_x2_bilinear_align_corners[] = {
83+ 0 .0000f , 0 .3333f , 0 .6667f , 1 .0000f ,
84+ 0 .6667f , 1 .1111f , 1 .5556f , 2 .0000f ,
85+ 1 .3333f , 1 .8889f , 2 .4444f , 3 .0000f ,
86+ 2 .0000f , 2 .6667f , 3 .3333f , 4 .0000f
87+ };
88+
89+ const float expected_upscale_x1_5_bilinear_align_corners[] = {
90+ 0 .0f , 1 .0f ,
91+ 1 .0f , 2 .5f ,
92+ 2 .0f , 4 .0f
93+ };
94+
95+ const float input_downscale[] = {
96+ 0 .0f , -1 .0f , -2 .0f , 0 .0f ,
97+ 1 .0f , 2 .0f , 4 .0f , 4 .0f ,
98+ 2 .0f , 2 .0f , 1 .0f , 1 .0f ,
99+
100+ 1 .0f , 2 .0f , 3 .0f , 4 .0f ,
101+ 2 .0f , 2 .0f , 2 .0f , 2 .0f ,
102+ -2 .0f , 2 .0f , -4 .0f , 4 .0f
103+ };
104+
105+ const float expected_downscale_nearest[] = {
106+ 0 .0f , -2 .0f ,
107+
108+ 1 .0f , 3 .0f
109+ };
110+
111+ const float expected_downscale_bilinear[] = {
112+ 0 .1667f , -0 .3750f , 0 .7500f ,
113+ 1 .7917f , 1 .8750f , 1 .7500f ,
114+
115+ 1 .3750f , 2 .3750f , 3 .3750f ,
116+ -0 .5000f , -0 .2500f , 2 .5000f
117+ };
118+
119+ const float expected_downscale_bilinear_align_corners[] = {
120+ 0 .0f , -1 .5f , 0 .0f ,
121+ 2 .0f , 1 .5f , 1 .0f ,
122+
123+ 1 .0f , 2 .5f , 4 .0f ,
124+ -2 .0f , -1 .0f , 4 .0f
125+ };
126+
127+ int main () {
128+ bool passed = true ;
129+
130+ passed &= test_interpolate (" upscale_x2_nearest" ,
131+ {2 , 2 , 1 , 1 }, input_upscale,
132+ {4 , 4 , 1 , 1 }, expected_upscale_x2_nearest,
133+ GGML_SCALE_MODE_NEAREST);
134+
135+ passed &= test_interpolate (" upscale_x2_bilinear" ,
136+ {2 , 2 , 1 , 1 }, input_upscale,
137+ {4 , 4 , 1 , 1 }, expected_upscale_x2_bilinear,
138+ GGML_SCALE_MODE_BILINEAR);
139+
140+ passed &= test_interpolate (" upscale_x2_bilinear_align_corners" ,
141+ {2 , 2 , 1 , 1 }, input_upscale,
142+ {4 , 4 , 1 , 1 }, expected_upscale_x2_bilinear_align_corners,
143+ GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);
144+
145+ passed &= test_interpolate (" upscale_x1_5_bilinear_align_corners" ,
146+ {2 , 2 , 1 , 1 }, input_upscale,
147+ {2 , 3 , 1 , 1 }, expected_upscale_x1_5_bilinear_align_corners,
148+ GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);
149+
150+ passed &= test_interpolate (" downscale_nearest" ,
151+ {4 , 3 , 2 , 1 }, input_downscale,
152+ {2 , 1 , 2 , 1 }, expected_downscale_nearest,
153+ GGML_SCALE_MODE_NEAREST);
154+
155+ passed &= test_interpolate (" downscale_bilinear" ,
156+ {4 , 3 , 2 , 1 }, input_downscale,
157+ {3 , 2 , 2 , 1 }, expected_downscale_bilinear,
158+ GGML_SCALE_MODE_BILINEAR);
159+
160+ passed &= test_interpolate (" downscale_bilinear_align_corners" ,
161+ {4 , 3 , 2 , 1 }, input_downscale,
162+ {3 , 2 , 2 , 1 }, expected_downscale_bilinear_align_corners,
163+ GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);
164+
165+ return passed ? 0 : 1 ;
166+ }
0 commit comments