@@ -493,19 +493,29 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
493493typedef std::function<void (ggml_tensor*, ggml_tensor*, bool )> on_tile_process;
494494
495495// Tiling
496- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
496+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true ) {
497497 int input_width = (int )input->ne [0 ];
498498 int input_height = (int )input->ne [1 ];
499499 int output_width = (int )output->ne [0 ];
500500 int output_height = (int )output->ne [1 ];
501+
502+ int input_tile_size, output_tile_size;
503+ if (scaled_out) {
504+ input_tile_size = tile_size;
505+ output_tile_size = tile_size * scale;
506+ } else {
507+ input_tile_size = tile_size * scale;
508+ output_tile_size = tile_size;
509+ }
510+
501511 GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
502512
503- int tile_overlap = (int32_t )(tile_size * tile_overlap_factor);
504- int non_tile_overlap = tile_size - tile_overlap;
513+ int tile_overlap = (int32_t )(input_tile_size * tile_overlap_factor);
514+ int non_tile_overlap = input_tile_size - tile_overlap;
505515
506516 struct ggml_init_params params = {};
507- params.mem_size += tile_size * tile_size * input->ne [2 ] * sizeof (float ); // input chunk
508- params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne [2 ] * sizeof (float ); // output chunk
517+ params.mem_size += input_tile_size * input_tile_size * input->ne [2 ] * sizeof (float ); // input chunk
518+ params.mem_size += output_tile_size * output_tile_size * output->ne [2 ] * sizeof (float ); // output chunk
509519 params.mem_size += 3 * ggml_tensor_overhead ();
510520 params.mem_buffer = NULL ;
511521 params.no_alloc = false ;
@@ -520,8 +530,8 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
520530 }
521531
522532 // tiling
523- ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size, tile_size , input->ne [2 ], 1 );
524- ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale , output->ne [2 ], 1 );
533+ ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size , input->ne [2 ], 1 );
534+ ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size , output->ne [2 ], 1 );
525535 on_processing (input_tile, NULL , true );
526536 int num_tiles = ceil ((float )input_width / non_tile_overlap) * ceil ((float )input_height / non_tile_overlap);
527537 LOG_INFO (" processing %i tiles" , num_tiles);
@@ -530,19 +540,23 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
530540 bool last_y = false , last_x = false ;
531541 float last_time = 0 .0f ;
532542 for (int y = 0 ; y < input_height && !last_y; y += non_tile_overlap) {
533- if (y + tile_size >= input_height) {
534- y = input_height - tile_size ;
543+ if (y + input_tile_size >= input_height) {
544+ y = input_height - input_tile_size ;
535545 last_y = true ;
536546 }
537547 for (int x = 0 ; x < input_width && !last_x; x += non_tile_overlap) {
538- if (x + tile_size >= input_width) {
539- x = input_width - tile_size ;
548+ if (x + input_tile_size >= input_width) {
549+ x = input_width - input_tile_size ;
540550 last_x = true ;
541551 }
542552 int64_t t1 = ggml_time_ms ();
543553 ggml_split_tensor_2d (input, input_tile, x, y);
544554 on_processing (input_tile, output_tile, false );
545- ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
555+ if (scaled_out) {
556+ ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
557+ } else {
558+ ggml_merge_tensor_2d (output_tile, output, x / scale, y / scale, tile_overlap / scale);
559+ }
546560 int64_t t2 = ggml_time_ms ();
547561 last_time = (t2 - t1) / 1000 .0f ;
548562 pretty_progress (tile_count, num_tiles, last_time);
0 commit comments