@@ -103,6 +103,64 @@ namespace DiT {
103103 x = ggml_ext_slice (ctx, x, 0 , 0 , W); // [N, C, H, W]
104104 return x;
105105 }
106+
107+ inline ggml_tensor* patchify (ggml_context* ctx,
108+ ggml_tensor* x,
109+ int pt,
110+ int ph,
111+ int pw,
112+ int64_t N = 1 ) {
113+ // x: [N*C, T, H, W]
114+ // return: [N, h*w, C*pt*ph*pw]
115+ int64_t C = x->ne [3 ] / N;
116+ int64_t T = x->ne [2 ];
117+ int64_t H = x->ne [1 ];
118+ int64_t W = x->ne [0 ];
119+ int64_t t_len = T / pt;
120+ int64_t h_len = H / ph;
121+ int64_t w_len = W / pw;
122+
123+ GGML_ASSERT (C * N == x->ne [3 ]);
124+ GGML_ASSERT (t_len * pt == T && h_len * ph == H && w_len * pw == W);
125+
126+ x = ggml_reshape_4d (ctx, x, pw * w_len, ph * h_len, pt, t_len * C * N); // [N*C*t_len, pt, h_len*ph, w_len*pw]
127+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [N*C*t_len, h_len*ph, pt, w_len*pw]
128+ x = ggml_reshape_4d (ctx, x, pw * w_len, pt, ph, h_len * t_len * C * N); // [N*C*t_len*h_len, ph, pt, w_len*pw]
129+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [N*C*t_len*h_len, pt, ph, w_len*pw]
130+ x = ggml_reshape_4d (ctx, x, pw, w_len, ph * pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt*ph, w_len, pw]
131+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [N*C*t_len*h_len, w_len, pt*ph, pw]
132+ x = ggml_reshape_4d (ctx, x, pw * ph * pt, w_len * h_len * t_len, C, N); // [N, C, t_len*h_len*w_len, pt*ph*pw]
133+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [N, t_len*h_len*w_len, C, pt*ph*pw]
134+ x = ggml_reshape_4d (ctx, x, pw * ph * pt * C, w_len * h_len * t_len, N, 1 ); // [N, t_len*h_len*w_len, C*pt*ph*pw]
135+ return x;
136+ }
137+
138+ inline ggml_tensor* unpatchify (ggml_context* ctx,
139+ ggml_tensor* x,
140+ int64_t t_len,
141+ int64_t h_len,
142+ int64_t w_len,
143+ int pt,
144+ int ph,
145+ int pw) {
146+ // x: [N, t_len*h_len*w_len, pt*ph*pw*C]
147+ // return: [N*C, t_len*pt, h_len*ph, w_len*pw]
148+ int64_t N = x->ne [3 ];
149+ int64_t C = x->ne [0 ] / pt / ph / pw;
150+
151+ GGML_ASSERT (C * pt * ph * pw == x->ne [0 ]);
152+
153+ x = ggml_reshape_4d (ctx, x, C, pw * ph * pt, w_len * h_len * t_len, N); // [N, t_len*h_len*w_len, pt*ph*pw, C]
154+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 1 , 2 , 0 , 3 )); // [N, C, t_len*h_len*w_len, pt*ph*pw]
155+ x = ggml_reshape_4d (ctx, x, pw, ph * pt, w_len, h_len * t_len * C * N); // [N*C*t_len*h_len, w_len, pt*ph, pw]
156+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [N*C*t_len*h_len, pt*ph, w_len, pw]
157+ x = ggml_reshape_4d (ctx, x, pw * w_len, ph, pt, h_len * t_len * C * N); // [N*C*t_len*h_len, pt, ph, w_len*pw]
158+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [N*C*t_len*h_len, ph, pt, w_len*pw]
159+ x = ggml_reshape_4d (ctx, x, pw * w_len, pt, ph * h_len, t_len * C * N); // [N*C*t_len, h_len*ph, pt, w_len*pw]
160+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [N*C*t_len, pt, h_len*ph, w_len*pw]
161+ x = ggml_reshape_4d (ctx, x, pw * w_len, ph * h_len, pt * t_len, C * N); // [N*C, t_len*pt, h_len*ph, w_len*pw]
162+ return x;
163+ }
106164} // namespace DiT
107165
108166#endif // __COMMON_DIT_HPP__
0 commit comments