|
4 | 4 |
|
5 | 5 |
|
6 | 6 | def window_partition_ttnn(x, window_size): |
7 | | - """Partition into non-overlapping windows""" |
8 | | - B, H, W, C = x.shape |
9 | | - num_windows = (H // window_size) * (W // window_size) |
10 | | - return ttnn.reshape(x, [B * num_windows, window_size, window_size, C], memory_config=ttnn.L1_MEMORY_CONFIG) |
| 7 | + """TTNN implementation of window partitioning""" |
| 8 | + b, h, w, c = x.shape |
11 | 9 |
|
| 10 | + # Reshape: (b, h, w, c) -> (b, h//ws, ws, w//ws, ws, c) |
| 11 | + reshaped = ttnn.reshape(x, (b, h // window_size, window_size, w // window_size, window_size, c)) |
12 | 12 |
|
13 | | -def window_reverse_ttnn(windows, window_size, H, W): |
14 | | - B = windows.shape[0] // (H * W // window_size // window_size) |
15 | | - return ttnn.reshape(windows, [B, H, W, -1], memory_config=ttnn.L1_MEMORY_CONFIG) |
| 13 | + # Permute: (0, 1, 3, 2, 4, 5) -> group windows together |
| 14 | + permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5)) |
| 15 | + |
| 16 | + # Final reshape to get windows |
| 17 | + windows = ttnn.reshape(permuted, (-1, window_size, window_size, c)) |
| 18 | + |
| 19 | + return windows |
| 20 | + |
| 21 | + |
| 22 | +def window_reverse_ttnn(windows, window_size, h, w): |
| 23 | + """TTNN implementation of window reverse""" |
| 24 | + b = int(windows.shape[0] / (h * w / window_size / window_size)) |
| 25 | + |
| 26 | + # Reshape windows back to grid |
| 27 | + reshaped = ttnn.reshape(windows, (b, h // window_size, w // window_size, window_size, window_size, -1)) |
| 28 | + |
| 29 | + # Permute back to original order |
| 30 | + permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5)) |
| 31 | + |
| 32 | + # Final reshape to original spatial dimensions |
| 33 | + output = ttnn.reshape(permuted, (b, h, w, -1)) |
| 34 | + |
| 35 | + return output |
16 | 36 |
|
17 | 37 |
|
18 | 38 | class TTOCAB(LightweightModule): |
|
0 commit comments