11# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
22# SPDX-License-Identifier: Apache-2.0
33
4- import pytest
54import torch
5+ import pytest
66from loguru import logger
77import ttnn
8- from ttnn .model_preprocessing import preprocess_model_parameters
98
9+ from ttnn .model_preprocessing import preprocess_model_parameters
1010from tests .ttnn .utils_for_testing import check_with_pcc
11- from models .experimental .panoptic_deeplab .tt .custom_preprocessing import create_custom_mesh_preprocessor
1211from models .experimental .panoptic_deeplab .reference .aspp import ASPPModel
1312from models .experimental .panoptic_deeplab .tt .aspp import TTASPP
1413from models .experimental .panoptic_deeplab .common import load_torch_model_state
14+ from models .experimental .panoptic_deeplab .tt .custom_preprocessing import create_custom_mesh_preprocessor
1515
1616
17- class AsppTestInfra :
18- def __init__ (
19- self ,
20- device ,
21- batch_size ,
22- input_channels ,
23- height ,
24- width ,
25- model_config ,
26- name ,
27- ):
17+ class ASPPTestInfra :
18+ def __init__ (self , device , batch_size , input_channels , height , width , model_config , name ):
2819 super ().__init__ ()
2920 if not hasattr (self , "_model_initialized" ):
3021 torch .manual_seed (42 )
31- self ._model_initialized = True
3222 torch .cuda .manual_seed_all (42 )
3323 torch .backends .cudnn .deterministic = True
34- self .pcc_passed = False
35- self .pcc_message = "call validate()?"
24+ self ._model_initialized = True
25+
26+ # Initialize core config
3627 self .device = device
37- self .num_devices = device .get_num_devices ()
3828 self .batch_size = batch_size
39- self .inputs_mesh_mapper , self .weights_mesh_mapper , self .output_mesh_composer = self .get_mesh_mappers (device )
29+ self .input_channels = input_channels
30+ self .height = height
31+ self .width = width
32+ self .model_config = model_config
4033 self .name = name
34+ self .num_devices = device .get_num_devices ()
35+
36+ # Mesh mappers
37+ self .inputs_mesh_mapper , self .weights_mesh_mapper , self .output_mesh_composer = self .get_mesh_mappers (device )
4138
42- # torch model
39+ logger .info (f"Initializing ASPP test for module: { name } " )
40+
41+ # Torch model
4342 torch_model = ASPPModel ()
4443 torch_model = load_torch_model_state (torch_model , name )
4544
45+ # Create synthetic input
46+ self .torch_input_tensor = self ._create_input_tensor ()
47+
48+ # Run torch model
49+ self .torch_output_tensor = torch_model (self .torch_input_tensor )
50+
51+ # Preprocess model
4652 parameters = preprocess_model_parameters (
4753 initialize_model = lambda : torch_model ,
4854 custom_preprocessor = create_custom_mesh_preprocessor (self .weights_mesh_mapper ),
4955 device = None ,
5056 )
5157
52- # golden
53- self .torch_input_tensor = torch .randn ((batch_size , input_channels , height , width ), dtype = torch .float )
54- self .torch_output_tensor = torch_model (self .torch_input_tensor )
58+ # Initialize TTNN model
59+ self .ttnn_model = TTASPP (parameters , model_config )
5560
56- # ttnn
57- tt_host_tensor = ttnn .from_torch (
58- self .torch_input_tensor .permute (0 , 2 , 3 , 1 ),
59- dtype = ttnn .bfloat8_b ,
60- device = device ,
61- mesh_mapper = self .inputs_mesh_mapper ,
62- )
61+ # Prepare TTNN input
62+ logger .info ("Converting input to TTNN tensor..." )
6363
64- self .ttnn_model = TTASPP (parameters , model_config )
65- self .input_tensor = ttnn .to_layout (tt_host_tensor , ttnn .TILE_LAYOUT )
66- self .input_tensor = ttnn .to_device (tt_host_tensor , device , memory_config = ttnn .L1_MEMORY_CONFIG )
64+ # Run model and validate
65+ for phase in ("JIT configuration" , "optimized" ):
66+ logger .info (f"Running TTNN model pass ({ phase } )..." )
67+
68+ # Re-convert input tensor (TTNN may deallocate buffers)
69+ tt_host_tensor = ttnn .from_torch (
70+ self .torch_input_tensor .permute (0 , 2 , 3 , 1 ),
71+ dtype = ttnn .bfloat8_b ,
72+ device = self .device ,
73+ mesh_mapper = self .inputs_mesh_mapper ,
74+ )
75+ self .input_tensor = ttnn .to_device (tt_host_tensor , self .device , memory_config = ttnn .L1_MEMORY_CONFIG )
76+
77+ # Optional: Re-instantiate model if it's not stateless
78+ self .ttnn_model = TTASPP (parameters , self .model_config )
79+
80+ self .run ()
81+ self .validate ()
6782
68- # run and validate
69- self .run ()
70- self .validate ()
83+ def _create_input_tensor (self ):
84+ shape = (self .batch_size * self .num_devices , self .input_channels , self .height , self .width )
85+ logger .info (f"Generating synthetic input tensor of shape { shape } " )
86+ return torch .randn (shape , dtype = torch .float32 )
7187
72- def get_mesh_mappers (self , device ):
88+ @classmethod
89+ def get_mesh_mappers (cls , device ):
7390 if device .get_num_devices () != 1 :
74- inputs_mesh_mapper = ttnn .ShardTensorToMesh (device , dim = 0 )
75- weights_mesh_mapper = None
76- output_mesh_composer = ttnn .ConcatMeshToTensor (device , dim = 0 )
77- else :
78- inputs_mesh_mapper = None
79- weights_mesh_mapper = None
80- output_mesh_composer = None
81- return inputs_mesh_mapper , weights_mesh_mapper , output_mesh_composer
91+ return (
92+ ttnn .ShardTensorToMesh (device , dim = 0 ), # inputs
93+ None , # weights
94+ ttnn .ConcatMeshToTensor (device , dim = 0 ), # outputs
95+ )
96+ return None , None , None
8297
8398 def run (self ):
99+ logger .info ("Running TTNN ASPP model..." )
84100 self .output_tensor = self .ttnn_model (self .input_tensor , self .device )
85101 return self .output_tensor
86102
87- def validate (self , output_tensor = None ):
88- tt_output_tensor = self .output_tensor if output_tensor is None else output_tensor
89- tt_output_tensor_torch = ttnn .to_torch (
90- tt_output_tensor , device = self .device , mesh_composer = self .output_mesh_composer
103+ def _tt_to_torch_nchw (self , tt_tensor , expected_shape ):
104+ torch_tensor = ttnn .to_torch (tt_tensor , device = self .device , mesh_composer = self .output_mesh_composer )
105+ torch_tensor = torch .reshape (
106+ torch_tensor ,
107+ (expected_shape [0 ], expected_shape [2 ], expected_shape [3 ], expected_shape [1 ]),
91108 )
109+ return torch .permute (torch_tensor , (0 , 3 , 1 , 2 ))
92110
93- # Deallocate output tesnors
94- ttnn .deallocate (tt_output_tensor )
111+ def validate (self ):
112+ logger .info ("Validating TTNN output against PyTorch..." )
113+ tt_output_tensor_torch = self ._tt_to_torch_nchw (self .output_tensor , self .torch_output_tensor .shape )
95114
96- expected_shape = self .torch_output_tensor .shape
97- tt_output_tensor_torch = torch .reshape (
98- tt_output_tensor_torch , (expected_shape [0 ], expected_shape [2 ], expected_shape [3 ], expected_shape [1 ])
99- )
100- tt_output_tensor_torch = torch .permute (tt_output_tensor_torch , (0 , 3 , 1 , 2 ))
101-
102- batch_size = tt_output_tensor_torch .shape [0 ]
115+ # Deallocate to save memory
116+ ttnn .deallocate (self .output_tensor )
103117
104- valid_pcc = 0.99
105- self .pcc_passed , self .pcc_message = check_with_pcc (
106- self .torch_output_tensor , tt_output_tensor_torch , pcc = valid_pcc
107- )
118+ pcc_threshold = 0.99
119+ passed , msg = check_with_pcc (self .torch_output_tensor , tt_output_tensor_torch , pcc = pcc_threshold )
120+ assert passed , logger .error (f"ASPP PCC check failed: { msg } " )
108121
109- assert self .pcc_passed , logger .error (f"PCC check failed: { self .pcc_message } " )
110122 logger .info (
111- f"Modular Panoptic DeepLab ASPP layer:{ self .name } - batch_size={ batch_size } , act_dtype={ model_config ['ACTIVATIONS_DTYPE' ]} , weight_dtype={ model_config ['WEIGHTS_DTYPE' ]} , math_fidelity={ model_config ['MATH_FIDELITY' ]} , PCC={ self .pcc_message } "
123+ f"ASPP layer `{ self .name } ` passed: "
124+ f"batch_size={ self .batch_size } , "
125+ f"act_dtype={ self .model_config ['ACTIVATIONS_DTYPE' ]} , "
126+ f"weight_dtype={ self .model_config ['WEIGHTS_DTYPE' ]} , "
127+ f"math_fidelity={ self .model_config ['MATH_FIDELITY' ]} , "
128+ f"PCC={ msg } "
112129 )
113130
114- return self . pcc_passed , self . pcc_message
131+ return True , msg
115132
116133
117134model_config = {
@@ -130,12 +147,12 @@ def validate(self, output_tensor=None):
130147)
131148@pytest .mark .parametrize ("name" , ["semantic_decoder.aspp" , "instance_decoder.aspp" ])
132149def test_aspp (device , batch_size , input_channels , height , width , name ):
133- AsppTestInfra (
134- device ,
135- batch_size ,
136- input_channels ,
137- height ,
138- width ,
139- model_config ,
140- name ,
150+ ASPPTestInfra (
151+ device = device ,
152+ batch_size = batch_size ,
153+ input_channels = input_channels ,
154+ height = height ,
155+ width = width ,
156+ model_config = model_config ,
157+ name = name ,
141158 )
0 commit comments