Skip to content

Commit 71d4b93

Browse files
test_ssr fix
1 parent 079bcf6 commit 71d4b93

File tree

3 files changed

+11
-16
lines changed

3 files changed

+11
-16
lines changed

models/experimental/SSR/tests/test_ssr.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
5252
forward_params = {"rpi_sa": tt_rpi_sa, "attn_mask": attn_mask, "rpi_oca": tt_rpi_oca}
5353
sr_params = preprocess_model_parameters(
5454
initialize_model=lambda: torch_model.sr_model,
55-
custom_preprocessor=create_tile_refinement_preprocessor(device, forward_params),
55+
custom_preprocessor=create_tile_refinement_preprocessor(
56+
device, forward_params, window_size=16, rpi_sa=rpi_sa
57+
),
5658
device=device,
5759
)
5860
parameters["sr_model"] = sr_params
@@ -104,7 +106,7 @@ def __init__(self):
104106
@pytest.mark.parametrize(
105107
"input_shape, num_cls, with_conv",
106108
[
107-
((1, 3, 256, 256), 1, True),
109+
# ((1, 3, 256, 256), 1, True),
108110
((1, 3, 256, 256), 1, False),
109111
],
110112
)
@@ -156,27 +158,21 @@ def test_ssr_model(input_shape, num_cls, with_conv):
156158
tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT)
157159

158160
# Run TTNN model
159-
tt_sr, tt_patch_fea3, tt_patch_fea2, tt_patch_fea1 = tt_model(tt_input)
161+
tt_sr, tt_patch_fea3 = tt_model(tt_input)
160162

161163
# Convert back to torch tensors
162164
tt_torch_sr = tt2torch_tensor(tt_sr)
163165
tt_torch_patch_fea3 = tt2torch_tensor(tt_patch_fea3)
164-
tt_torch_patch_fea2 = tt2torch_tensor(tt_patch_fea2)
165-
tt_torch_patch_fea1 = tt2torch_tensor(tt_patch_fea1)
166166
tt_torch_sr = tt_torch_sr.permute(0, 3, 1, 2)
167167

168168
# Compare outputs
169169
sr_pass, sr_pcc_message = comp_pcc(ref_sr, tt_torch_sr, 0.95)
170170
fea3_pass, fea3_pcc_message = comp_pcc(ref_patch_fea3, tt_torch_patch_fea3, 0.95)
171-
fea2_pass, fea2_pcc_message = comp_pcc(ref_patch_fea2, tt_torch_patch_fea2, 0.95)
172-
fea1_pass, fea1_pcc_message = comp_pcc(ref_patch_fea1, tt_torch_patch_fea1, 0.95)
173171

174172
logger.info(f"SR Output PCC: {sr_pcc_message}")
175173
logger.info(f"Patch Fea3 PCC: {fea3_pcc_message}")
176-
logger.info(f"Patch Fea2 PCC: {fea2_pcc_message}")
177-
logger.info(f"Patch Fea1 PCC: {fea1_pcc_message}")
178174

179-
all_pass = sr_pass and fea3_pass and fea2_pass and fea1_pass
175+
all_pass = sr_pass and fea3_pass
180176

181177
if all_pass:
182178
logger.info("TTSSR Test Passed!")
@@ -185,8 +181,6 @@ def test_ssr_model(input_shape, num_cls, with_conv):
185181

186182
assert sr_pass, f"SR output comparison failed: {sr_pcc_message}"
187183
assert fea3_pass, f"Patch fea3 comparison failed: {fea3_pcc_message}"
188-
assert fea2_pass, f"Patch fea2 comparison failed: {fea2_pcc_message}"
189-
assert fea1_pass, f"Patch fea1 comparison failed: {fea1_pcc_message}"
190184

191185
finally:
192186
ttnn.close_device(device)

models/experimental/SSR/tt/ssr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def forward(self, x):
7979
B, C, H, W = x.shape
8080

8181
# Get tile selection features
82-
patch_fea3, patch_fea2, patch_fea1 = self.select_model(x)
82+
patch_fea3 = self.select_model(x)
8383

8484
# Calculate selection threshold (top 25%)
8585
patch_fea3_flat = ttnn.reshape(patch_fea3, (-1,))
@@ -199,7 +199,7 @@ def forward(self, x):
199199

200200
sr = ttnn.reshape(sr, [B, 1024, 1024, 3]) # TODO
201201

202-
return sr, patch_fea3, patch_fea2, patch_fea1
202+
return sr, patch_fea3
203203

204204

205205
class TTSSR_wo_conv(LightweightModule):
@@ -237,7 +237,7 @@ def forward(self, x):
237237
B, C, H, W = x.shape
238238

239239
# Same tile selection logic
240-
patch_fea3, patch_fea2, patch_fea1 = self.select_model(x)
240+
patch_fea3 = self.select_model(x)
241241

242242
# Calculate selection threshold (top 25%)
243243
patch_fea3_flat = ttnn.reshape(patch_fea3, (-1,))
@@ -281,4 +281,4 @@ def forward(self, x):
281281
sr = ttnn.concat(sr_patches, dim=0)
282282
sr = window_reverse_ttnn(sr, window_size=H, h=H * 4, w=W * 4)
283283

284-
return sr, patch_fea3, patch_fea2, patch_fea1
284+
return sr, patch_fea3

models/experimental/SSR/tt/tile_selection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def forward(self, x):
105105
Args:
106106
x: Input tensor [B, C, H, W]
107107
"""
108+
x = ttnn.permute(x, (0, 2, 3, 1))
108109
B, C, H, W = x.shape
109110

110111
# Patch embedding using existing TTPatchEmbed

0 commit comments

Comments
 (0)