Skip to content

Commit cf3dc21

Browse files
committed
Window partition and window reverse are now in SSR & uses L1 memory config
1 parent 6597d01 commit cf3dc21

File tree

2 files changed

+40
-36
lines changed
  • models/experimental/SSR/tt

2 files changed

+40
-36
lines changed

models/experimental/SSR/tt/ssr.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,44 @@
88
from models.experimental.SSR.tt.tile_refinement.tile_refinement import TTTileRefinement
99
from models.experimental.SSR.tt.tile_selection.tile_selection import TTTileSelection
1010
from models.experimental.SSR.tt.tile_refinement.upsample import TTUpsample
11-
from models.experimental.SSR.tt.tile_refinement.RHAG.ATTEN_BLK.OCAB.OCAB import (
12-
window_partition_ttnn,
13-
window_reverse_ttnn,
14-
)
11+
12+
13+
def window_partition_ttnn(x, window_size):
14+
"""TTNN implementation of window partitioning"""
15+
b, h, w, c = x.shape
16+
17+
# Reshape: (b, h, w, c) -> (b, h//ws, ws, w//ws, ws, c)
18+
x = ttnn.reshape(
19+
x, (b, h // window_size, window_size, w // window_size, window_size, c), memory_config=ttnn.L1_MEMORY_CONFIG
20+
)
21+
22+
# Permute: (0, 1, 3, 2, 4, 5) -> group windows together
23+
x = ttnn.permute(x, (0, 1, 3, 2, 4, 5), memory_config=ttnn.L1_MEMORY_CONFIG)
24+
25+
# Final reshape to get windows
26+
x = ttnn.reshape(x, (-1, window_size, window_size, c), memory_config=ttnn.L1_MEMORY_CONFIG)
27+
28+
return x
29+
30+
31+
def window_reverse_ttnn(windows, window_size, h, w):
32+
"""TTNN implementation of window reverse"""
33+
b = int(windows.shape[0] / (h * w / window_size / window_size))
34+
35+
# Reshape windows back to grid
36+
windows = ttnn.reshape(
37+
windows,
38+
(b, h // window_size, w // window_size, window_size, window_size, -1),
39+
memory_config=ttnn.L1_MEMORY_CONFIG,
40+
)
41+
42+
# Permute back to original order
43+
windows = ttnn.permute(windows, (0, 1, 3, 2, 4, 5), memory_config=ttnn.L1_MEMORY_CONFIG)
44+
45+
# Final reshape to original spatial dimensions
46+
windows = ttnn.reshape(windows, (b, h, w, -1), memory_config=ttnn.L1_MEMORY_CONFIG)
47+
48+
return windows
1549

1650

1751
class TTSSR(LightweightModule):
@@ -283,6 +317,8 @@ def forward(self, x):
283317
negX = ttnn.to_layout(negX, ttnn.TILE_LAYOUT)
284318
sr_patches.append(negX)
285319

320+
ttnn.deallocate(patch_x)
321+
286322
# Concatenate and reconstruct
287323
sr = ttnn.concat(sr_patches, dim=0)
288324
sr = window_reverse_ttnn(sr, window_size=H, h=H * 4, w=W * 4)

models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/OCAB/OCAB.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,6 @@
66
from models.common.lightweightmodule import LightweightModule
77

88

9-
def window_partition_ttnn(x, window_size):
10-
"""TTNN implementation of window partitioning"""
11-
b, h, w, c = x.shape
12-
13-
# Reshape: (b, h, w, c) -> (b, h//ws, ws, w//ws, ws, c)
14-
reshaped = ttnn.reshape(x, (b, h // window_size, window_size, w // window_size, window_size, c))
15-
16-
# Permute: (0, 1, 3, 2, 4, 5) -> group windows together
17-
permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5))
18-
19-
# Final reshape to get windows
20-
windows = ttnn.reshape(permuted, (-1, window_size, window_size, c))
21-
22-
return windows
23-
24-
25-
def window_reverse_ttnn(windows, window_size, h, w):
26-
"""TTNN implementation of window reverse"""
27-
b = int(windows.shape[0] / (h * w / window_size / window_size))
28-
29-
# Reshape windows back to grid
30-
reshaped = ttnn.reshape(windows, (b, h // window_size, w // window_size, window_size, window_size, -1))
31-
32-
# Permute back to original order
33-
permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5))
34-
35-
# Final reshape to original spatial dimensions
36-
output = ttnn.reshape(permuted, (b, h, w, -1))
37-
38-
return output
39-
40-
419
class TTOCAB(LightweightModule):
4210
def __init__(
4311
self,

0 commit comments

Comments
 (0)