11import pytest
22import torch
33import ttnn
4- from tests .ttnn .utils_for_testing import assert_with_pcc
4+ from loguru import logger
5+ from models .utility_functions import comp_pcc
6+
57from models .utility_functions import torch_random
68
79from models .experimental .SSR .reference .SSR .model .tile_refinement import PatchUnEmbed
810from models .experimental .SSR .tt .patch_unembed import TTPatchUnEmbed
911
1012
11- @pytest .mark .parametrize ("batch_size" , [1 , 2 , 8 ])
12- @pytest .mark .parametrize ("img_size" , [64 ])
13- @pytest .mark .parametrize ("patch_size" , [1 ])
14- @pytest .mark .parametrize ("in_chans" , [180 ])
15- @pytest .mark .parametrize ("embed_dim" , [180 ])
13+ @pytest .mark .parametrize (
14+ "batch_size, img_size, patch_size, in_chans, embed_dim" ,
15+ [
16+ # (1, 64, 4, 3, 180), # TR blk test
17+ (1 , 64 , 2 , 3 , 180 ), # HAT blk test
18+ ],
19+ )
1620def test_tt_patch_unembed (device , batch_size , img_size , patch_size , in_chans , embed_dim ):
1721 """Test TTPatchUnEmbed against PyTorch reference implementation"""
1822 torch .manual_seed (0 )
@@ -37,17 +41,29 @@ def test_tt_patch_unembed(device, batch_size, img_size, patch_size, in_chans, em
3741
3842 # Convert input to TTNN format
3943 ttnn_input = ttnn .from_torch (
40- torch_input , dtype = ttnn .bfloat16 , layout = ttnn .TILE_LAYOUT , device = device , memory_config = ttnn .DRAM_MEMORY_CONFIG
44+ torch_input , dtype = ttnn .bfloat16 , layout = ttnn .TILE_LAYOUT , device = device , memory_config = ttnn .L1_MEMORY_CONFIG
4145 )
4246
4347 # Run TTNN model
4448 ttnn_output = tt_model (ttnn_input , patches_resolution )
4549 ttnn_output_torch = ttnn .to_torch (ttnn_output )
4650
51+ # Compare outputs
52+ does_pass , pcc_message = comp_pcc (torch_output , ttnn_output_torch , 0.99 )
53+
54+ logger .info (f"Reference output shape: { torch_output .shape } " )
55+ logger .info (f"TTNN output shape: { ttnn_output_torch .shape } " )
56+ logger .info (pcc_message )
57+
58+ if does_pass :
59+ logger .info ("TR PatchEmbed Passed!" )
60+ else :
61+ logger .warning ("TR PatchEmbed Failed!" )
62+
63+ assert does_pass , f"PCC check failed: { pcc_message } "
4764 # Assert shapes match
4865 assert (
4966 torch_output .shape == ttnn_output_torch .shape
5067 ), f"Shape mismatch: { torch_output .shape } vs { ttnn_output_torch .shape } "
5168
52- # Assert values match with PCC
53- assert_with_pcc (torch_output , ttnn_output_torch , 0.999 )
69+ ttnn .close_device (device )
0 commit comments