@@ -1376,14 +1376,24 @@ class StableDiffusionGGML {
13761376 ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
13771377 int64_t t0 = ggml_time_ms ();
13781378 ggml_tensor* result = NULL ;
1379+ int W = x->ne [0 ] / 8 ;
1380+ int H = x->ne [1 ] / 8 ;
1381+ if (vae_tiling && !decode_video) {
1382+ // TODO wan2.2 vae support?
1383+ int C = sd_version_is_dit (version) ? 16 : 4 ;
1384+ if (!use_tiny_autoencoder) {
1385+ C *= 2 ;
1386+ }
1387+ result = ggml_new_tensor_4d (work_ctx, GGML_TYPE_F32, W, H, C, x->ne [3 ]);
1388+ }
13791389 // TODO: args instead of env for tile size / overlap?
13801390 if (!use_tiny_autoencoder) {
13811391 float tile_overlap = 0 .5f ;
13821392 int tile_size_x = 32 ;
13831393 int tile_size_y = 32 ;
13841394
13851395 get_vae_tile_overlap (tile_overlap);
1386- get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x-> ne [ 0 ] / 8 , x-> ne [ 1 ] / 8 );
1396+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, W, H );
13871397
13881398 // TODO: also use an arg for this one?
13891399 // multiply tile size for encode to keep the compute buffer size consistent
@@ -1393,7 +1403,7 @@ class StableDiffusionGGML {
13931403 process_vae_input_tensor (x);
13941404 if (vae_tiling && !decode_video) {
13951405 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1396- first_stage_model->compute (n_threads, in, true , &out, NULL );
1406+ first_stage_model->compute (n_threads, in, false , &out, work_ctx );
13971407 };
13981408 sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap, on_tiling);
13991409 } else {
@@ -1404,7 +1414,7 @@ class StableDiffusionGGML {
14041414 if (vae_tiling && !decode_video) {
14051415 // split latent in 32x32 tiles and compute in several steps
14061416 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1407- tae_first_stage->compute (n_threads, in, true , &out, NULL );
1417+ tae_first_stage->compute (n_threads, in, false , &out, NULL );
14081418 };
14091419 sd_tiling (x, result, 8 , 64 , 0 .5f , on_tiling);
14101420 } else {
0 commit comments