Skip to content

Commit 46ad202

Browse files
window partition, reverse OCAB fix, ssr pcc: 0.999952
1 parent 9e79d30 commit 46ad202

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

models/experimental/SSR/tests/test_ssr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def test_ssr_model(input_shape, num_cls, with_conv):
155155
)
156156

157157
# Convert input to TTNN tensor
158-
tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT)
158+
tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16)
159159

160160
# Run TTNN model
161161
tt_sr, tt_patch_fea3 = tt_model(tt_input)

models/experimental/SSR/tt/OCAB.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,35 @@
44

55

66
def window_partition_ttnn(x, window_size):
7-
"""Partition into non-overlapping windows"""
8-
B, H, W, C = x.shape
9-
num_windows = (H // window_size) * (W // window_size)
10-
return ttnn.reshape(x, [B * num_windows, window_size, window_size, C], memory_config=ttnn.L1_MEMORY_CONFIG)
7+
"""TTNN implementation of window partitioning"""
8+
b, h, w, c = x.shape
119

10+
# Reshape: (b, h, w, c) -> (b, h//ws, ws, w//ws, ws, c)
11+
reshaped = ttnn.reshape(x, (b, h // window_size, window_size, w // window_size, window_size, c))
1212

13-
def window_reverse_ttnn(windows, window_size, H, W):
14-
B = windows.shape[0] // (H * W // window_size // window_size)
15-
return ttnn.reshape(windows, [B, H, W, -1], memory_config=ttnn.L1_MEMORY_CONFIG)
13+
# Permute: (0, 1, 3, 2, 4, 5) -> group windows together
14+
permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5))
15+
16+
# Final reshape to get windows
17+
windows = ttnn.reshape(permuted, (-1, window_size, window_size, c))
18+
19+
return windows
20+
21+
22+
def window_reverse_ttnn(windows, window_size, h, w):
23+
"""TTNN implementation of window reverse"""
24+
b = int(windows.shape[0] / (h * w / window_size / window_size))
25+
26+
# Reshape windows back to grid
27+
reshaped = ttnn.reshape(windows, (b, h // window_size, w // window_size, window_size, window_size, -1))
28+
29+
# Permute back to original order
30+
permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5))
31+
32+
# Final reshape to original spatial dimensions
33+
output = ttnn.reshape(permuted, (b, h, w, -1))
34+
35+
return output
1636

1737

1838
class TTOCAB(LightweightModule):

0 commit comments

Comments
 (0)