@@ -529,19 +529,29 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
// Callback type used by sd_tiling to process a single tile.
// Arguments as passed by sd_tiling: (input tile, output tile, flag).
// sd_tiling invokes it once as (input_tile, NULL, true) before the tile loop,
// then once per tile as (input_tile, output_tile, false).
// NOTE(review): the bool presumably marks the initial setup call (no output
// tile is supplied then) -- confirm against the lambdas passed by callers.
529529typedef std::function<void (ggml_tensor*, ggml_tensor*, bool )> on_tile_process;
530530
531531// Tiling
532- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
532+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true ) {
533533 int input_width = (int )input->ne [0 ];
534534 int input_height = (int )input->ne [1 ];
535535 int output_width = (int )output->ne [0 ];
536536 int output_height = (int )output->ne [1 ];
537+
538+ int input_tile_size, output_tile_size;
539+ if (scaled_out) {
540+ input_tile_size = tile_size;
541+ output_tile_size = tile_size * scale;
542+ } else {
543+ input_tile_size = tile_size * scale;
544+ output_tile_size = tile_size;
545+ }
546+
537547 GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
538548
539- int tile_overlap = (int32_t )(tile_size * tile_overlap_factor);
540- int non_tile_overlap = tile_size - tile_overlap;
549+ int tile_overlap = (int32_t )(input_tile_size * tile_overlap_factor);
550+ int non_tile_overlap = input_tile_size - tile_overlap;
541551
542552 struct ggml_init_params params = {};
543- params.mem_size += tile_size * tile_size * input->ne [2 ] * sizeof (float ); // input chunk
544- params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne [2 ] * sizeof (float ); // output chunk
553+ params.mem_size += input_tile_size * input_tile_size * input->ne [2 ] * sizeof (float ); // input chunk
554+ params.mem_size += output_tile_size * output_tile_size * output->ne [2 ] * sizeof (float ); // output chunk
545555 params.mem_size += 3 * ggml_tensor_overhead ();
546556 params.mem_buffer = NULL ;
547557 params.no_alloc = false ;
@@ -556,8 +566,8 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
556566 }
557567
558568 // tiling
559- ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size, tile_size , input->ne [2 ], 1 );
560- ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale , output->ne [2 ], 1 );
569+ ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size , input->ne [2 ], 1 );
570+ ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size , output->ne [2 ], 1 );
561571 on_processing (input_tile, NULL , true );
562572 int num_tiles = ceil ((float )input_width / non_tile_overlap) * ceil ((float )input_height / non_tile_overlap);
563573 LOG_INFO (" processing %i tiles" , num_tiles);
@@ -566,19 +576,23 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
566576 bool last_y = false , last_x = false ;
567577 float last_time = 0 .0f ;
568578 for (int y = 0 ; y < input_height && !last_y; y += non_tile_overlap) {
569- if (y + tile_size >= input_height) {
570- y = input_height - tile_size ;
579+ if (y + input_tile_size >= input_height) {
580+ y = input_height - input_tile_size ;
571581 last_y = true ;
572582 }
573583 for (int x = 0 ; x < input_width && !last_x; x += non_tile_overlap) {
574- if (x + tile_size >= input_width) {
575- x = input_width - tile_size ;
584+ if (x + input_tile_size >= input_width) {
585+ x = input_width - input_tile_size ;
576586 last_x = true ;
577587 }
578588 int64_t t1 = ggml_time_ms ();
579589 ggml_split_tensor_2d (input, input_tile, x, y);
580590 on_processing (input_tile, output_tile, false );
581- ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
591+ if (scaled_out) {
592+ ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
593+ } else {
594+ ggml_merge_tensor_2d (output_tile, output, x / scale, y / scale, tile_overlap / scale);
595+ }
582596 int64_t t2 = ggml_time_ms ();
583597 last_time = (t2 - t1) / 1000 .0f ;
584598 pretty_progress (tile_count, num_tiles, last_time);
0 commit comments