55import torch
66from loguru import logger
77import ttnn
8- from ttnn .model_preprocessing import (
9- preprocess_model_parameters ,
10- infer_ttnn_module_args ,
11- )
8+ from ttnn .model_preprocessing import preprocess_model_parameters
129from tests .ttnn .utils_for_testing import check_with_pcc
1310
1411from models .experimental .panoptic_deeplab .reference .panoptic_deeplab import TorchPanopticDeepLab
1512from models .experimental .panoptic_deeplab .tt .panoptic_deeplab import TTPanopticDeepLab
1613from models .experimental .panoptic_deeplab .tt .custom_preprocessing import create_custom_mesh_preprocessor
14+ from ttnn .model_preprocessing import infer_ttnn_module_args , preprocess_model_parameters
15+ from models .experimental .panoptic_deeplab .common import load_torch_model_state
1716
1817
1918class PanopticDeepLabTestInfra :
20- _seeded = False
21- _PCC_THRESH = 0.97
22-
23- def __init__ (self , device , batch_size , in_channels , height , width , model_config ):
19+ def __init__ (
20+ self ,
21+ device ,
22+ batch_size ,
23+ in_channels ,
24+ height ,
25+ width ,
26+ model_config ,
27+ ):
2428 super ().__init__ ()
25- self ._maybe_seed ()
29+ if not hasattr (self , "_model_initialized" ):
30+ torch .manual_seed (42 )
31+ self ._model_initialized = True
32+ torch .cuda .manual_seed_all (42 )
33+ torch .backends .cudnn .deterministic = True
2634
27- # Core state
35+ self .pcc_passed = False
36+ self .pcc_message = "call validate()?"
2837 self .device = device
29- self .model_config = model_config
3038 self .num_devices = device .get_num_devices ()
31- self .batch_size , self .in_channels , self .height , self .width = (
32- batch_size ,
33- in_channels ,
34- height ,
35- width ,
36- )
39+ self .batch_size = batch_size
40+ self .in_channels = in_channels
41+ self .height = height
42+ self .width = width
3743 self .inputs_mesh_mapper , self .weights_mesh_mapper , self .output_mesh_composer = self .get_mesh_mappers (device )
3844
39- # Torch reference model + inputs
40- torch_model = TorchPanopticDeepLab ().eval ()
45+ # Initialize torch model
46+ torch_model = TorchPanopticDeepLab ()
47+ torch_model = load_torch_model_state (torch_model , "panoptic_deeplab" )
48+
49+ # Create input tensor
4150 input_shape = (batch_size * self .num_devices , in_channels , height , width )
4251 self .torch_input_tensor = torch .rand (input_shape , dtype = torch .float )
4352
44- # Preprocess TTNN parameters
53+ # Preprocess model parameters
4554 parameters = preprocess_model_parameters (
4655 initialize_model = lambda : torch_model ,
4756 custom_preprocessor = create_custom_mesh_preprocessor (self .weights_mesh_mapper ),
@@ -51,23 +60,27 @@ def __init__(self, device, batch_size, in_channels, height, width, model_config)
5160 # Populate conv_args for decoders via one small warm-up pass
5261 self ._populate_all_decoders (torch_model , parameters )
5362
54- # Run Torch once (fp32) → then bf16 for parity with TTNN
63+ # Run torch model with bfloat16
5564 logger .info ("Running PyTorch model..." )
5665 self .torch_output_tensor , self .torch_output_tensor_2 , self .torch_output_tensor_3 = torch_model (
5766 self .torch_input_tensor
5867 )
5968
60- # Convert input to TTNN NHWC host tensor
69+ # Convert input to TTNN format (NHWC)
6170 logger .info ("Converting input to TTNN format..." )
6271 tt_host_tensor = ttnn .from_torch (
6372 self .torch_input_tensor .permute (0 , 2 , 3 , 1 ),
6473 dtype = ttnn .bfloat16 ,
6574 mesh_mapper = self .inputs_mesh_mapper ,
6675 )
6776
68- # TTNN model
77+ # Initialize TTNN model
6978 logger .info ("Initializing TTNN model..." )
70- self .ttnn_model = TTPanopticDeepLab (parameters = parameters , model_config = model_config )
79+ print ("Initializing TTNN model..." )
80+ self .ttnn_model = TTPanopticDeepLab (
81+ parameters = parameters ,
82+ model_config = model_config ,
83+ )
7184
7285 # First run configures JIT, second run is optimized
7386 for phase in ("JIT configuration" , "optimized" ):
@@ -76,16 +89,7 @@ def __init__(self, device, batch_size, in_channels, height, width, model_config)
7689 self .run ()
7790 self .validate ()
7891
79- # --------------------------- Setup & helpers ---------------------------
80-
8192 @classmethod
82- def _maybe_seed (cls ):
83- if not cls ._seeded :
84- torch .manual_seed (42 )
85- torch .cuda .manual_seed_all (42 )
86- torch .backends .cudnn .deterministic = True
87- cls ._seeded = True
88-
8993 def get_mesh_mappers (self , device ):
9094 if device .get_num_devices () != 1 :
9195 return (
@@ -143,8 +147,6 @@ def _tt_to_torch_nchw(tt_tensor, device, mesh_composer, expected_shape):
143147 t = torch .reshape (t , (expected_shape [0 ], expected_shape [2 ], expected_shape [3 ], expected_shape [1 ]))
144148 return torch .permute (t , (0 , 3 , 1 , 2 ))
145149
146- # --------------------------- Core runs/validation ---------------------------
147-
148150 def run (self ):
149151 self .output_tensor , self .output_tensor_2 , self .output_tensor_3 = self .ttnn_model (self .input_tensor , self .device )
150152 return self .output_tensor , self .output_tensor_2 , self .output_tensor_3
@@ -157,24 +159,24 @@ def validate(self):
157159 ("Instance Segmentation Center Head" , self .output_tensor_3 , self .torch_output_tensor_3 ),
158160 ]
159161
162+ self ._PCC_THRESH = 0.97
163+
160164 for name , tt_out , torch_ref in checks :
161165 out = self ._tt_to_torch_nchw (tt_out , self .device , self .output_mesh_composer , torch_ref .shape )
162166 passed , msg = check_with_pcc (torch_ref , out , pcc = self ._PCC_THRESH )
163167 assert passed , logger .error (f"{ name } PCC check failed: { msg } " )
164168
165169 logger .info (
166170 f"Panoptic DeepLab - { name } : batch_size={ self .batch_size } , "
167- f"act_dtype={ self . model_config ['ACTIVATIONS_DTYPE' ]} , "
168- f"weight_dtype={ self . model_config ['WEIGHTS_DTYPE' ]} , "
169- f"math_fidelity={ self . model_config ['MATH_FIDELITY' ]} , "
171+ f"act_dtype={ model_config ['ACTIVATIONS_DTYPE' ]} , "
172+ f"weight_dtype={ model_config ['WEIGHTS_DTYPE' ]} , "
173+ f"math_fidelity={ model_config ['MATH_FIDELITY' ]} , "
170174 f"PCC={ msg } , shape={ tt_out .shape } "
171175 )
172176
173177 return True , f"All heads passed PCC ≥ { self ._PCC_THRESH } "
174178
175179
176- # --------------------------- Test config ---------------------------
177-
178180model_config = {
179181 "MATH_FIDELITY" : ttnn .MathFidelity .LoFi ,
180182 "WEIGHTS_DTYPE" : ttnn .bfloat8_b ,
@@ -183,6 +185,24 @@ def validate(self):
183185
184186
185187@pytest .mark .parametrize ("device_params" , [{"l1_small_size" : 16384 }], indirect = True )
186- @pytest .mark .parametrize ("batch_size, in_channels, height, width" , [(1 , 3 , 512 , 1024 )])
187- def test_panoptic_deeplab (device , batch_size , in_channels , height , width ):
188- PanopticDeepLabTestInfra (device , batch_size , in_channels , height , width , model_config )
188+ @pytest .mark .parametrize (
189+ "batch_size, in_channels, height, width" ,
190+ [
191+ (1 , 3 , 512 , 1024 ),
192+ ],
193+ )
194+ def test_panoptic_deeplab (
195+ device ,
196+ batch_size ,
197+ in_channels ,
198+ height ,
199+ width ,
200+ ):
201+ PanopticDeepLabTestInfra (
202+ device ,
203+ batch_size ,
204+ in_channels ,
205+ height ,
206+ width ,
207+ model_config ,
208+ )
0 commit comments