@@ -491,19 +491,30 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
491491typedef std::function<void (ggml_tensor*, ggml_tensor*, bool )> on_tile_process;
492492
493493// Tiling
494- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
494+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true ) {
495495 int input_width = (int )input->ne [0 ];
496496 int input_height = (int )input->ne [1 ];
497497 int output_width = (int )output->ne [0 ];
498498 int output_height = (int )output->ne [1 ];
499+
500+ int input_tile_size, output_tile_size;
501+ if ( scaled_out ){
502+ input_tile_size = tile_size;
503+ output_tile_size = tile_size * scale;
504+ } else {
505+ input_tile_size = tile_size * scale;
506+ output_tile_size = tile_size;
507+ }
508+
509+
499510 GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
500511
501- int tile_overlap = (int32_t )(tile_size * tile_overlap_factor);
502- int non_tile_overlap = tile_size - tile_overlap;
512+ int tile_overlap = (int32_t )(input_tile_size * tile_overlap_factor);
513+ int non_tile_overlap = input_tile_size - tile_overlap;
503514
504515 struct ggml_init_params params = {};
505- params.mem_size += tile_size * tile_size * input->ne [2 ] * sizeof (float ); // input chunk
506- params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne [2 ] * sizeof (float ); // output chunk
516+ params.mem_size += input_tile_size * input_tile_size * input->ne [2 ] * sizeof (float ); // input chunk
517+ params.mem_size += output_tile_size * output_tile_size * output->ne [2 ] * sizeof (float ); // output chunk
507518 params.mem_size += 3 * ggml_tensor_overhead ();
508519 params.mem_buffer = NULL ;
509520 params.no_alloc = false ;
@@ -518,8 +529,9 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
518529 }
519530
520531 // tiling
521- ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne [2 ], 1 );
522- ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne [2 ], 1 );
532+ ggml_tensor *input_tile, *output_tile;
533+ input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne [2 ], 1 );
534+ output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne [2 ], 1 );
523535 on_processing (input_tile, NULL , true );
524536 int num_tiles = ceil ((float )input_width / non_tile_overlap) * ceil ((float )input_height / non_tile_overlap);
525537 LOG_INFO (" processing %i tiles" , num_tiles);
@@ -528,19 +540,23 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
528540 bool last_y = false , last_x = false ;
529541 float last_time = 0 .0f ;
530542 for (int y = 0 ; y < input_height && !last_y; y += non_tile_overlap) {
531- if (y + tile_size >= input_height) {
532- y = input_height - tile_size ;
543+ if (y + input_tile_size >= input_height) {
544+ y = input_height - input_tile_size ;
533545 last_y = true ;
534546 }
535547 for (int x = 0 ; x < input_width && !last_x; x += non_tile_overlap) {
536- if (x + tile_size >= input_width) {
537- x = input_width - tile_size ;
548+ if (x + input_tile_size >= input_width) {
549+ x = input_width - input_tile_size ;
538550 last_x = true ;
539551 }
540552 int64_t t1 = ggml_time_ms ();
541553 ggml_split_tensor_2d (input, input_tile, x, y);
542554 on_processing (input_tile, output_tile, false );
543- ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
555+ if (scaled_out){
556+ ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
557+ } else {
558+ ggml_merge_tensor_2d (output_tile, output, x / scale, y / scale, tile_overlap / scale);
559+ }
544560 int64_t t2 = ggml_time_ms ();
545561 last_time = (t2 - t1) / 1000 .0f ;
546562 pretty_progress (tile_count, num_tiles, last_time);
@@ -673,13 +689,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
673689#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
674690 struct ggml_tensor * kqv = ggml_flash_attn (ctx, q, k, v, false ); // [N * n_head, n_token, d_head]
675691#else
676- float d_head = (float )q->ne [0 ];
692+ float d_head = (float )q->ne [0 ];
677693 struct ggml_tensor * kq = ggml_mul_mat (ctx, k, q); // [N * n_head, n_token, n_k]
678694 kq = ggml_scale_inplace (ctx, kq, 1 .0f / sqrt (d_head));
679695 if (mask) {
680696 kq = ggml_diag_mask_inf_inplace (ctx, kq, 0 );
681697 }
682- kq = ggml_soft_max_inplace (ctx, kq);
698+ kq = ggml_soft_max_inplace (ctx, kq);
683699 struct ggml_tensor * kqv = ggml_mul_mat (ctx, v, kq); // [N * n_head, n_token, d_head]
684700#endif
685701 return kqv;
0 commit comments