Commit 22aa8a8

nkpatel-tt authored and ign-febin committed
Execute untilize + RM Fold if tiled input tensor's channels are a multiple of 32 (tenstorrent#28379)

### Ticket
tenstorrent#22378

### What's changed
For 32-aligned input channels, `untilize + RM Fold` performs better than `TL Fold + reshape`.

### Checklist
- [ ] [Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml) CI [running](https://github.com/tenstorrent/tt-metal/actions/runs/17667596363)
- [ ] [Nightly tt-metal L2 tests](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml) CI [running](https://github.com/tenstorrent/tt-metal/actions/runs/17667587225)

Signed-off-by: Nilaykumar Patel <[email protected]>
1 parent 7c218fc commit 22aa8a8

File tree

  • ttnn/cpp/ttnn/operations/data_movement/fold

1 file changed: +8 −2 lines

ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp

Lines changed: 8 additions & 2 deletions
```diff
@@ -328,9 +328,15 @@ Tensor FoldOperation::invoke(
     auto input_height = input_tensor.logical_shape()[1];
     auto input_width = input_tensor.logical_shape()[2];
     auto in_channels = input_tensor.logical_shape()[3];
+    auto fold_input_tensor = input_tensor;
+    if (in_channels % 32 == 0 && fold_input_tensor.layout() == Layout::TILE) {
+        // Convert to row-major layout for 32-channel aligned tensors to leverage faster untilize+RM fold path
+        fold_input_tensor = ttnn::to_layout(input_tensor, Layout::ROW_MAJOR);
+    }
+
     auto output_tensor =
-        ttnn::prim::fold(queue_id, input_tensor, stride_h, stride_w, output_shape, pad_c, pad_h, pad_w);
-    if (input_tensor.layout() == Layout::TILE) {
+        ttnn::prim::fold(queue_id, fold_input_tensor, stride_h, stride_w, output_shape, pad_c, pad_h, pad_w);
+    if (fold_input_tensor.layout() == Layout::TILE) {
         return ttnn::reshape(
             output_tensor,
             ttnn::Shape(
```

0 commit comments