Skip to content

Commit 4b10242

Browse files
skrsticTTign-febin
authored andcommitted
Fast tilize in tiled out case for pool ops (tenstorrent#28382)
Use `fast_tilize_*` instead od `tilize_*` API for tilizing pool's output (better perf). Currently it does not support bfp4_b output so it is handled with classic tilize, opened issue tenstorrent#28380. ### Checklist - [x] [All post commit](https://github.com/tenstorrent/tt-metal/actions/runs/17725852995) - [x] [Nightly tt-metal L2 tests](https://github.com/tenstorrent/tt-metal/actions/runs/17725837243) (vae sdxl already regressed, not related to this change)
1 parent f385280 commit 4b10242

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

ttnn/cpp/ttnn/operations/pool/generic/device/kernels/compute/compute_pool_2d.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void MAIN {
5555
constexpr uint32_t pre_tilize_cb_id = get_compile_time_arg_val(19);
5656
constexpr bool is_output_tiled = get_compile_time_arg_val(20); // 1 = TILED, 0 = ROW_MAJOR
5757
constexpr bool is_output_block_format = (bool)get_compile_time_arg_val(21);
58-
58+
constexpr bool is_output_bfp4_b = (bool)get_compile_time_arg_val(22);
5959

6060
constexpr uint32_t topk_output_tiles = 1;
6161
constexpr uint32_t topk_cb_tile_idx = 0;
@@ -228,11 +228,17 @@ void MAIN {
228228
unary_op_init_common(pre_tilize_cb_id, out_cb_id);
229229
tensix_sync();
230230

231-
tilize_init(pre_tilize_cb_id, in_ntiles_c, out_cb_id);
232-
tilize_block(pre_tilize_cb_id, in_ntiles_c, out_cb_id);
233-
231+
// Skip fast_tilize path for bfp4_b output until #28380 is closed
232+
if constexpr (is_output_bfp4_b) {
233+
tilize_init(pre_tilize_cb_id, in_ntiles_c, out_cb_id);
234+
tilize_block(pre_tilize_cb_id, in_ntiles_c, out_cb_id);
235+
tilize_uninit(pre_tilize_cb_id, out_cb_id);
236+
} else {
237+
fast_tilize_init(pre_tilize_cb_id, in_ntiles_c, out_cb_id);
238+
fast_tilize_block(pre_tilize_cb_id, in_ntiles_c, out_cb_id);
239+
fast_tilize_uninit(pre_tilize_cb_id, out_cb_id);
240+
}
234241
cb_push_back(out_cb_id, in_ntiles_c);
235-
tilize_uninit(pre_tilize_cb_id, out_cb_id);
236242

237243
if constexpr (is_output_block_format) {
238244
tensix_sync();

ttnn/cpp/ttnn/operations/pool/generic/device/pool_multi_core_program_factory.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ Pool2D::MultiCore::cached_program_t pool2d_multi_core_sharded_with_halo_v2_impl_
418418

419419
const bool is_output_tiled = output_layout == Layout::TILE;
420420
const bool is_output_block_format = is_block_float(outputs[0].dtype());
421+
const bool is_output_bfp4_b = outputs[0].dtype() == DataType::BFLOAT4_B;
421422

422423
// Conditionally allocate temporary CB - only needed for TILED output
423424
uint32_t pre_tilize_cb_id = 32; // default invalid CB ID
@@ -618,7 +619,8 @@ Pool2D::MultiCore::cached_program_t pool2d_multi_core_sharded_with_halo_v2_impl_
618619
(uint32_t)return_indices, // 18
619620
pre_tilize_cb_id, // 19
620621
is_output_tiled, // 20
621-
is_output_block_format}; // 21
622+
is_output_block_format, // 21
623+
is_output_bfp4_b}; // 22
622624

623625
auto compute_config = tt::tt_metal::ComputeConfig{
624626
.math_fidelity = MathFidelity::HiFi4,

ttnn/cpp/ttnn/operations/pool/grid_sample/device/grid_sample_program_factory.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ tt::tt_metal::operation::ProgramWithCallbacks grid_sample_program_factory(
155155

156156
const bool is_output_tiled = false;
157157
const bool is_output_block_format = false;
158+
const bool is_output_bfp4_b = false;
158159
const uint32_t pre_tilize_cb_id =
159160
32; // Unused CB for pool compute kernel for grid sample, we don't have tiled output in gridsample
160161

@@ -182,7 +183,8 @@ tt::tt_metal::operation::ProgramWithCallbacks grid_sample_program_factory(
182183
false, // 18: Return Indices (unused)
183184
pre_tilize_cb_id, // 19: Pre-tilize CB (unused)
184185
is_output_tiled, // 20: is_output_tiled (unused)
185-
is_output_block_format // 21: is_output_block_format (unused)
186+
is_output_block_format, // 21: is_output_block_format (unused)
187+
is_output_bfp4_b // 22: is_output_bfp4_b (unused)
186188
};
187189

188190
compute_kernel_group_1 = tt::tt_metal::CreateKernel(
@@ -221,7 +223,8 @@ tt::tt_metal::operation::ProgramWithCallbacks grid_sample_program_factory(
221223
false, // 18: Return Indices (unused)
222224
pre_tilize_cb_id, // 19: Pre-tilize CB (unused)
223225
is_output_tiled, // 20: is_output_tiled (unused)
224-
is_output_block_format // 21: is_output_block_format (unused)
226+
is_output_block_format, // 21: is_output_block_format (unused)
227+
is_output_bfp4_b // 22: is_output_bfp4_b (unused)
225228
};
226229

227230
compute_kernel_group_2 = tt::tt_metal::CreateKernel(

0 commit comments

Comments
 (0)