|
8 | 8 | from models.experimental.SSR.tt.tile_refinement.tile_refinement import TTTileRefinement |
9 | 9 | from models.experimental.SSR.tt.tile_selection.tile_selection import TTTileSelection |
10 | 10 | from models.experimental.SSR.tt.tile_refinement.upsample import TTUpsample |
11 | | -from models.experimental.SSR.tt.tile_refinement.RHAG.ATTEN_BLK.OCAB.OCAB import ( |
12 | | - window_partition_ttnn, |
13 | | - window_reverse_ttnn, |
14 | | -) |
| 11 | + |
| 12 | + |
| 13 | +def window_partition_ttnn(x, window_size): |
| 14 | + """TTNN implementation of window partitioning""" |
| 15 | + b, h, w, c = x.shape |
| 16 | + |
| 17 | + # Reshape: (b, h, w, c) -> (b, h//ws, ws, w//ws, ws, c) |
| 18 | + x = ttnn.reshape( |
| 19 | + x, (b, h // window_size, window_size, w // window_size, window_size, c), memory_config=ttnn.L1_MEMORY_CONFIG |
| 20 | + ) |
| 21 | + |
| 22 | + # Permute: (0, 1, 3, 2, 4, 5) -> group windows together |
| 23 | + x = ttnn.permute(x, (0, 1, 3, 2, 4, 5), memory_config=ttnn.L1_MEMORY_CONFIG) |
| 24 | + |
| 25 | + # Final reshape to get windows |
| 26 | + x = ttnn.reshape(x, (-1, window_size, window_size, c), memory_config=ttnn.L1_MEMORY_CONFIG) |
| 27 | + |
| 28 | + return x |
| 29 | + |
| 30 | + |
| 31 | +def window_reverse_ttnn(windows, window_size, h, w): |
| 32 | + """TTNN implementation of window reverse""" |
| 33 | + b = int(windows.shape[0] / (h * w / window_size / window_size)) |
| 34 | + |
| 35 | + # Reshape windows back to grid |
| 36 | + windows = ttnn.reshape( |
| 37 | + windows, |
| 38 | + (b, h // window_size, w // window_size, window_size, window_size, -1), |
| 39 | + memory_config=ttnn.L1_MEMORY_CONFIG, |
| 40 | + ) |
| 41 | + |
| 42 | + # Permute back to original order |
| 43 | + windows = ttnn.permute(windows, (0, 1, 3, 2, 4, 5), memory_config=ttnn.L1_MEMORY_CONFIG) |
| 44 | + |
| 45 | + # Final reshape to original spatial dimensions |
| 46 | + windows = ttnn.reshape(windows, (b, h, w, -1), memory_config=ttnn.L1_MEMORY_CONFIG) |
| 47 | + |
| 48 | + return windows |
15 | 49 |
|
16 | 50 |
|
17 | 51 | class TTSSR(LightweightModule): |
@@ -283,6 +317,8 @@ def forward(self, x): |
283 | 317 | negX = ttnn.to_layout(negX, ttnn.TILE_LAYOUT) |
284 | 318 | sr_patches.append(negX) |
285 | 319 |
|
| 320 | + ttnn.deallocate(patch_x) |
| 321 | + |
286 | 322 | # Concatenate and reconstruct |
287 | 323 | sr = ttnn.concat(sr_patches, dim=0) |
288 | 324 | sr = window_reverse_ttnn(sr, window_size=H, h=H * 4, w=W * 4) |
|
0 commit comments