@@ -5455,6 +5455,121 @@ struct test_pad_reflect_1d : public test_case {
54555455 }
54565456};
54575457
5458+ // GGML_OP_WIN_PART
5459+ struct test_win_part : public test_case {
5460+ const ggml_type type;
5461+ const std::array<int64_t , 4 > ne_a; // [C, W, H, B]
5462+ const int w; // window size
5463+ const bool v; // view (non-contiguous input)
5464+
5465+ std::string vars () override {
5466+ return VARS_TO_STR4 (type, ne_a, w, v);
5467+ }
5468+
5469+ test_win_part (ggml_type type = GGML_TYPE_F32,
5470+ std::array<int64_t , 4 > ne_a = {64 , 14 , 14 , 2 },
5471+ int w = 7 ,
5472+ bool v = false )
5473+ : type(type), ne_a(ne_a), w(w), v(v) {}
5474+
5475+ ggml_tensor * build_graph (ggml_context * ctx) override {
5476+ ggml_tensor * a;
5477+ if (v) {
5478+ auto ne = ne_a; ne[0 ] *= 2 ; ne[1 ] *= 2 ;
5479+ a = ggml_new_tensor (ctx, type, 4 , ne.data ());
5480+ ggml_set_name (a, " a" );
5481+
5482+ a = ggml_view_4d (ctx, a, ne_a[0 ], ne_a[1 ], ne_a[2 ], ne_a[3 ],
5483+ a->nb [1 ], a->nb [2 ], a->nb [3 ], 0 );
5484+ ggml_set_name (a, " view_of_a" );
5485+ } else {
5486+ a = ggml_new_tensor (ctx, type, 4 , ne_a.data ());
5487+ ggml_set_name (a, " a" );
5488+ }
5489+
5490+ ggml_tensor * out = ggml_win_part (ctx, a, w);
5491+ ggml_set_name (out, " out" );
5492+
5493+ return out;
5494+ }
5495+ };
5496+
5497+ // GGML_OP_WIN_UNPART
5498+ struct test_win_unpart : public test_case {
5499+ const ggml_type type;
5500+ const std::array<int64_t , 4 > ne_a; // [C, w, w, NPX*NPY*B]
5501+ const int w0; // output width
5502+ const int h0; // output height
5503+ const int w; // window size
5504+
5505+ std::string vars () override {
5506+ return VARS_TO_STR5 (type, ne_a, w0, h0, w);
5507+ }
5508+
5509+ test_win_unpart (ggml_type type = GGML_TYPE_F32,
5510+ std::array<int64_t , 4 > ne_a = {64 , 7 , 7 , 8 },
5511+ int w0 = 14 , int h0 = 14 ,
5512+ int w = 7 )
5513+ : type(type), ne_a(ne_a), w0(w0), h0(h0), w(w) {}
5514+
5515+ ggml_tensor * build_graph (ggml_context * ctx) override {
5516+ ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne_a.data ());
5517+ ggml_set_name (a, " a" );
5518+
5519+ ggml_tensor * out = ggml_win_unpart (ctx, a, w0, h0, w);
5520+ ggml_set_name (out, " out" );
5521+
5522+ return out;
5523+ }
5524+ };
5525+
5526+ // GGML_OP_GET_REL_POS
5527+ struct test_get_rel_pos : public test_case {
5528+ const ggml_type type;
5529+ const int C; // channels
5530+ const int qh; // query height
5531+ const int kh; // key height
5532+ const bool v; // view (non-contiguous input)
5533+
5534+ std::string vars () override {
5535+ return VARS_TO_STR5 (type, C, qh, kh, v);
5536+ }
5537+
5538+ test_get_rel_pos (ggml_type type = GGML_TYPE_F32,
5539+ int C = 64 ,
5540+ int qh = 7 ,
5541+ int kh = 7 ,
5542+ bool v = false )
5543+ : type(type), C(C), qh(qh), kh(kh), v(v) {}
5544+
5545+ ggml_tensor * build_graph (ggml_context * ctx) override {
5546+ // Input tensor has relative position embeddings table
5547+ // Shape: [C, 2*max(qh,kh)-1, 1, 1]
5548+ const int64_t ne_a[4 ] = {C, 2 *std::max (qh, kh) - 1 , 1 , 1 };
5549+
5550+ ggml_tensor * a;
5551+ if (v) {
5552+ // Create larger tensor and view into it (non-contiguous)
5553+ int64_t ne_large[4 ] = {C * 2 , 2 *std::max (qh, kh) - 1 , 1 , 1 };
5554+ a = ggml_new_tensor (ctx, type, 4 , ne_large);
5555+ ggml_set_name (a, " a" );
5556+
5557+ a = ggml_view_4d (ctx, a, C, 2 *std::max (qh, kh) - 1 , 1 , 1 ,
5558+ a->nb [1 ], a->nb [2 ], a->nb [3 ], 0 );
5559+ ggml_set_name (a, " view_of_a" );
5560+ } else {
5561+ a = ggml_new_tensor (ctx, type, 4 , ne_a);
5562+ ggml_set_name (a, " a" );
5563+ }
5564+
5565+ // Output shape: [C, kh, qh, 1]
5566+ ggml_tensor * out = ggml_get_rel_pos (ctx, a, qh, kh);
5567+ ggml_set_name (out, " out" );
5568+
5569+ return out;
5570+ }
5571+ };
5572+
54585573// GGML_OP_ROLL
54595574struct test_roll : public test_case {
54605575 const int shift0;
@@ -7326,6 +7441,53 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
73267441 test_cases.emplace_back (new test_pad_ext ());
73277442 test_cases.emplace_back (new test_pad_reflect_1d ());
73287443 test_cases.emplace_back (new test_pad_reflect_1d (GGML_TYPE_F32, {3000 , 384 , 4 , 1 }));
7444+
7445+ // Window partition tests
7446+ for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
7447+ for (bool v : {false , true }) {
7448+ // Exact division: 14x14 -> 2x2 windows of 7x7
7449+ test_cases.emplace_back (new test_win_part (type, {64 , 14 , 14 , 2 }, 7 , v));
7450+ // With padding: 15x15 -> 3x3 windows of 7x7 (padded)
7451+ test_cases.emplace_back (new test_win_part (type, {64 , 15 , 15 , 2 }, 7 , v));
7452+ // Single window: 7x7 -> 1x1 windows of 7x7
7453+ test_cases.emplace_back (new test_win_part (type, {64 , 7 , 7 , 1 }, 7 , v));
7454+ // Larger: 28x28 -> 4x4 windows of 7x7
7455+ test_cases.emplace_back (new test_win_part (type, {128 , 28 , 28 , 4 }, 7 , v));
7456+ // Window size 8: 16x16 -> 2x2 windows of 8x8
7457+ test_cases.emplace_back (new test_win_part (type, {96 , 16 , 16 , 1 }, 8 , v));
7458+ }
7459+ }
7460+
7461+ // Window unpartition tests (inverse of partition)
7462+ for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
7463+ // Exact division: 2x2 windows of 7x7 -> 14x14
7464+ test_cases.emplace_back (new test_win_unpart (type, {64 , 7 , 7 , 4 }, 14 , 14 , 7 ));
7465+ // With padding: 3x3 windows of 7x7 -> 15x15
7466+ test_cases.emplace_back (new test_win_unpart (type, {64 , 7 , 7 , 9 }, 15 , 15 , 7 ));
7467+ // Single window: 1x1 windows of 7x7 -> 7x7
7468+ test_cases.emplace_back (new test_win_unpart (type, {64 , 7 , 7 , 1 }, 7 , 7 , 7 ));
7469+ // Larger: 4x4 windows of 7x7 -> 28x28
7470+ test_cases.emplace_back (new test_win_unpart (type, {128 , 7 , 7 , 16 }, 28 , 28 , 7 ));
7471+ // Window size 8: 2x2 windows of 8x8 -> 16x16
7472+ test_cases.emplace_back (new test_win_unpart (type, {96 , 8 , 8 , 4 }, 16 , 16 , 8 ));
7473+ }
7474+
7475+ // Relative position embedding tests (used in SAM)
7476+ for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
7477+ for (bool v : {false , true }) {
7478+ // Square small: 3x3 attention
7479+ test_cases.emplace_back (new test_get_rel_pos (type, 5 , 3 , 3 , v));
7480+ // Square medium: 7x7 attention (typical SAM)
7481+ test_cases.emplace_back (new test_get_rel_pos (type, 13 , 7 , 7 , v));
7482+ // Square large: 14x14 attention
7483+ test_cases.emplace_back (new test_get_rel_pos (type, 27 , 14 , 14 , v));
7484+ // Rectangular: 14x7 attention
7485+ test_cases.emplace_back (new test_get_rel_pos (type, 27 , 14 , 7 , v));
7486+ // Edge case: 1x1 attention (minimum)
7487+ test_cases.emplace_back (new test_get_rel_pos (type, 1 , 1 , 1 , v));
7488+ }
7489+ }
7490+
73297491 test_cases.emplace_back (new test_roll ());
73307492 test_cases.emplace_back (new test_arange ());
73317493 test_cases.emplace_back (new test_timestep_embedding ());
0 commit comments