@@ -5518,6 +5518,121 @@ struct test_pad_reflect_1d : public test_case {
55185518 }
55195519};
55205520
5521+ // GGML_OP_WIN_PART
5522+ struct test_win_part : public test_case {
5523+ const ggml_type type;
5524+ const std::array<int64_t , 4 > ne_a; // [C, W, H, B]
5525+ const int w; // window size
5526+ const bool v; // view (non-contiguous input)
5527+
5528+ std::string vars () override {
5529+ return VARS_TO_STR4 (type, ne_a, w, v);
5530+ }
5531+
5532+ test_win_part (ggml_type type = GGML_TYPE_F32,
5533+ std::array<int64_t , 4 > ne_a = {64 , 14 , 14 , 2 },
5534+ int w = 7 ,
5535+ bool v = false )
5536+ : type(type), ne_a(ne_a), w(w), v(v) {}
5537+
5538+ ggml_tensor * build_graph (ggml_context * ctx) override {
5539+ ggml_tensor * a;
5540+ if (v) {
5541+ auto ne = ne_a; ne[0 ] *= 2 ; ne[1 ] *= 2 ;
5542+ a = ggml_new_tensor (ctx, type, 4 , ne.data ());
5543+ ggml_set_name (a, " a" );
5544+
5545+ a = ggml_view_4d (ctx, a, ne_a[0 ], ne_a[1 ], ne_a[2 ], ne_a[3 ],
5546+ a->nb [1 ], a->nb [2 ], a->nb [3 ], 0 );
5547+ ggml_set_name (a, " view_of_a" );
5548+ } else {
5549+ a = ggml_new_tensor (ctx, type, 4 , ne_a.data ());
5550+ ggml_set_name (a, " a" );
5551+ }
5552+
5553+ ggml_tensor * out = ggml_win_part (ctx, a, w);
5554+ ggml_set_name (out, " out" );
5555+
5556+ return out;
5557+ }
5558+ };
5559+
5560+ // GGML_OP_WIN_UNPART
5561+ struct test_win_unpart : public test_case {
5562+ const ggml_type type;
5563+ const std::array<int64_t , 4 > ne_a; // [C, w, w, NPX*NPY*B]
5564+ const int w0; // output width
5565+ const int h0; // output height
5566+ const int w; // window size
5567+
5568+ std::string vars () override {
5569+ return VARS_TO_STR5 (type, ne_a, w0, h0, w);
5570+ }
5571+
5572+ test_win_unpart (ggml_type type = GGML_TYPE_F32,
5573+ std::array<int64_t , 4 > ne_a = {64 , 7 , 7 , 8 },
5574+ int w0 = 14 , int h0 = 14 ,
5575+ int w = 7 )
5576+ : type(type), ne_a(ne_a), w0(w0), h0(h0), w(w) {}
5577+
5578+ ggml_tensor * build_graph (ggml_context * ctx) override {
5579+ ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne_a.data ());
5580+ ggml_set_name (a, " a" );
5581+
5582+ ggml_tensor * out = ggml_win_unpart (ctx, a, w0, h0, w);
5583+ ggml_set_name (out, " out" );
5584+
5585+ return out;
5586+ }
5587+ };
5588+
5589+ // GGML_OP_GET_REL_POS
5590+ struct test_get_rel_pos : public test_case {
5591+ const ggml_type type;
5592+ const int C; // channels
5593+ const int qh; // query height
5594+ const int kh; // key height
5595+ const bool v; // view (non-contiguous input)
5596+
5597+ std::string vars () override {
5598+ return VARS_TO_STR5 (type, C, qh, kh, v);
5599+ }
5600+
5601+ test_get_rel_pos (ggml_type type = GGML_TYPE_F32,
5602+ int C = 64 ,
5603+ int qh = 7 ,
5604+ int kh = 7 ,
5605+ bool v = false )
5606+ : type(type), C(C), qh(qh), kh(kh), v(v) {}
5607+
5608+ ggml_tensor * build_graph (ggml_context * ctx) override {
5609+ // Input tensor has relative position embeddings table
5610+ // Shape: [C, 2*max(qh,kh)-1, 1, 1]
5611+ const int64_t ne_a[4 ] = {C, 2 *std::max (qh, kh) - 1 , 1 , 1 };
5612+
5613+ ggml_tensor * a;
5614+ if (v) {
5615+ // Create larger tensor and view into it (non-contiguous)
5616+ int64_t ne_large[4 ] = {C * 2 , 2 *std::max (qh, kh) - 1 , 1 , 1 };
5617+ a = ggml_new_tensor (ctx, type, 4 , ne_large);
5618+ ggml_set_name (a, " a" );
5619+
5620+ a = ggml_view_4d (ctx, a, C, 2 *std::max (qh, kh) - 1 , 1 , 1 ,
5621+ a->nb [1 ], a->nb [2 ], a->nb [3 ], 0 );
5622+ ggml_set_name (a, " view_of_a" );
5623+ } else {
5624+ a = ggml_new_tensor (ctx, type, 4 , ne_a);
5625+ ggml_set_name (a, " a" );
5626+ }
5627+
5628+ // Output shape: [C, kh, qh, 1]
5629+ ggml_tensor * out = ggml_get_rel_pos (ctx, a, qh, kh);
5630+ ggml_set_name (out, " out" );
5631+
5632+ return out;
5633+ }
5634+ };
5635+
55215636// GGML_OP_ROLL
55225637struct test_roll : public test_case {
55235638 const int shift0;
@@ -7565,6 +7680,53 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
75657680 test_cases.emplace_back (new test_pad_ext ());
75667681 test_cases.emplace_back (new test_pad_reflect_1d ());
75677682 test_cases.emplace_back (new test_pad_reflect_1d (GGML_TYPE_F32, {3000 , 384 , 4 , 1 }));
7683+
7684+ // Window partition tests
7685+ for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
7686+ for (bool v : {false , true }) {
7687+ // Exact division: 14x14 -> 2x2 windows of 7x7
7688+ test_cases.emplace_back (new test_win_part (type, {64 , 14 , 14 , 2 }, 7 , v));
7689+ // With padding: 15x15 -> 3x3 windows of 7x7 (padded)
7690+ test_cases.emplace_back (new test_win_part (type, {64 , 15 , 15 , 2 }, 7 , v));
7691+ // Single window: 7x7 -> 1x1 windows of 7x7
7692+ test_cases.emplace_back (new test_win_part (type, {64 , 7 , 7 , 1 }, 7 , v));
7693+ // Larger: 28x28 -> 4x4 windows of 7x7
7694+ test_cases.emplace_back (new test_win_part (type, {128 , 28 , 28 , 4 }, 7 , v));
7695+ // Window size 8: 16x16 -> 2x2 windows of 8x8
7696+ test_cases.emplace_back (new test_win_part (type, {96 , 16 , 16 , 1 }, 8 , v));
7697+ }
7698+ }
7699+
7700+ // Window unpartition tests (inverse of partition)
7701+ for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
7702+ // Exact division: 2x2 windows of 7x7 -> 14x14
7703+ test_cases.emplace_back (new test_win_unpart (type, {64 , 7 , 7 , 4 }, 14 , 14 , 7 ));
7704+ // With padding: 3x3 windows of 7x7 -> 15x15
7705+ test_cases.emplace_back (new test_win_unpart (type, {64 , 7 , 7 , 9 }, 15 , 15 , 7 ));
7706+ // Single window: 1x1 windows of 7x7 -> 7x7
7707+ test_cases.emplace_back (new test_win_unpart (type, {64 , 7 , 7 , 1 }, 7 , 7 , 7 ));
7708+ // Larger: 4x4 windows of 7x7 -> 28x28
7709+ test_cases.emplace_back (new test_win_unpart (type, {128 , 7 , 7 , 16 }, 28 , 28 , 7 ));
7710+ // Window size 8: 2x2 windows of 8x8 -> 16x16
7711+ test_cases.emplace_back (new test_win_unpart (type, {96 , 8 , 8 , 4 }, 16 , 16 , 8 ));
7712+ }
7713+
7714+ // Relative position embedding tests (used in SAM)
7715+ for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
7716+ for (bool v : {false , true }) {
7717+ // Square small: 3x3 attention
7718+ test_cases.emplace_back (new test_get_rel_pos (type, 5 , 3 , 3 , v));
7719+ // Square medium: 7x7 attention (typical SAM)
7720+ test_cases.emplace_back (new test_get_rel_pos (type, 13 , 7 , 7 , v));
7721+ // Square large: 14x14 attention
7722+ test_cases.emplace_back (new test_get_rel_pos (type, 27 , 14 , 14 , v));
7723+ // Rectangular: 14x7 attention
7724+ test_cases.emplace_back (new test_get_rel_pos (type, 27 , 14 , 7 , v));
7725+ // Edge case: 1x1 attention (minimum)
7726+ test_cases.emplace_back (new test_get_rel_pos (type, 1 , 1 , 1 , v));
7727+ }
7728+ }
7729+
75687730 test_cases.emplace_back (new test_roll ());
75697731 test_cases.emplace_back (new test_arange ());
75707732 test_cases.emplace_back (new test_timestep_embedding ());
0 commit comments