@@ -52,7 +52,9 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
5252 forward_params = {"rpi_sa" : tt_rpi_sa , "attn_mask" : attn_mask , "rpi_oca" : tt_rpi_oca }
5353 sr_params = preprocess_model_parameters (
5454 initialize_model = lambda : torch_model .sr_model ,
55- custom_preprocessor = create_tile_refinement_preprocessor (device , forward_params ),
55+ custom_preprocessor = create_tile_refinement_preprocessor (
56+ device , forward_params , window_size = 16 , rpi_sa = rpi_sa
57+ ),
5658 device = device ,
5759 )
5860 parameters ["sr_model" ] = sr_params
@@ -104,7 +106,7 @@ def __init__(self):
104106@pytest .mark .parametrize (
105107 "input_shape, num_cls, with_conv" ,
106108 [
107- ((1 , 3 , 256 , 256 ), 1 , True ),
109+ # ((1, 3, 256, 256), 1, True),
108110 ((1 , 3 , 256 , 256 ), 1 , False ),
109111 ],
110112)
@@ -156,27 +158,21 @@ def test_ssr_model(input_shape, num_cls, with_conv):
156158 tt_input = ttnn .from_torch (x , device = device , layout = ttnn .TILE_LAYOUT )
157159
158160 # Run TTNN model
159- tt_sr , tt_patch_fea3 , tt_patch_fea2 , tt_patch_fea1 = tt_model (tt_input )
161+ tt_sr , tt_patch_fea3 = tt_model (tt_input )
160162
161163 # Convert back to torch tensors
162164 tt_torch_sr = tt2torch_tensor (tt_sr )
163165 tt_torch_patch_fea3 = tt2torch_tensor (tt_patch_fea3 )
164- tt_torch_patch_fea2 = tt2torch_tensor (tt_patch_fea2 )
165- tt_torch_patch_fea1 = tt2torch_tensor (tt_patch_fea1 )
166166 tt_torch_sr = tt_torch_sr .permute (0 , 3 , 1 , 2 )
167167
168168 # Compare outputs
169169 sr_pass , sr_pcc_message = comp_pcc (ref_sr , tt_torch_sr , 0.95 )
170170 fea3_pass , fea3_pcc_message = comp_pcc (ref_patch_fea3 , tt_torch_patch_fea3 , 0.95 )
171- fea2_pass , fea2_pcc_message = comp_pcc (ref_patch_fea2 , tt_torch_patch_fea2 , 0.95 )
172- fea1_pass , fea1_pcc_message = comp_pcc (ref_patch_fea1 , tt_torch_patch_fea1 , 0.95 )
173171
174172 logger .info (f"SR Output PCC: { sr_pcc_message } " )
175173 logger .info (f"Patch Fea3 PCC: { fea3_pcc_message } " )
176- logger .info (f"Patch Fea2 PCC: { fea2_pcc_message } " )
177- logger .info (f"Patch Fea1 PCC: { fea1_pcc_message } " )
178174
179- all_pass = sr_pass and fea3_pass and fea2_pass and fea1_pass
175+ all_pass = sr_pass and fea3_pass
180176
181177 if all_pass :
182178 logger .info ("TTSSR Test Passed!" )
@@ -185,8 +181,6 @@ def test_ssr_model(input_shape, num_cls, with_conv):
185181
186182 assert sr_pass , f"SR output comparison failed: { sr_pcc_message } "
187183 assert fea3_pass , f"Patch fea3 comparison failed: { fea3_pcc_message } "
188- assert fea2_pass , f"Patch fea2 comparison failed: { fea2_pcc_message } "
189- assert fea1_pass , f"Patch fea1 comparison failed: { fea1_pcc_message } "
190184
191185 finally :
192186 ttnn .close_device (device )
0 commit comments