55import torch
66from loguru import logger
77import ttnn
8- from ttnn .model_preprocessing import preprocess_model_parameters
8+ from ttnn .model_preprocessing import (
9+ preprocess_model_parameters ,
10+ infer_ttnn_module_args ,
11+ )
912from tests .ttnn .utils_for_testing import check_with_pcc
1013
1114from models .experimental .panoptic_deeplab .reference .panoptic_deeplab import TorchPanopticDeepLab
1215from models .experimental .panoptic_deeplab .tt .panoptic_deeplab import TTPanopticDeepLab
1316from models .experimental .panoptic_deeplab .tt .custom_preprocessing import create_custom_mesh_preprocessor
14- from ttnn .model_preprocessing import infer_ttnn_module_args , preprocess_model_parameters
15- from models .experimental .panoptic_deeplab .common import load_torch_model_state
1617
1718
1819class PanopticDeepLabTestInfra :
19- def __init__ (
20- self ,
21- device ,
22- batch_size ,
23- in_channels ,
24- height ,
25- width ,
26- model_config ,
27- ):
20+ _seeded = False
21+ _PCC_THRESH = 0.97
22+
23+ def __init__ (self , device , batch_size , in_channels , height , width , model_config ):
2824 super ().__init__ ()
29- if not hasattr (self , "_model_initialized" ):
30- torch .manual_seed (42 )
31- self ._model_initialized = True
32- torch .cuda .manual_seed_all (42 )
33- torch .backends .cudnn .deterministic = True
25+ self ._maybe_seed ()
3426
35- self .pcc_passed = False
36- self .pcc_message = "call validate()?"
27+ # Core state
3728 self .device = device
29+ self .model_config = model_config
3830 self .num_devices = device .get_num_devices ()
39- self .batch_size = batch_size
40- self .in_channels = in_channels
41- self .height = height
42- self .width = width
31+ self .batch_size , self .in_channels , self .height , self .width = (
32+ batch_size ,
33+ in_channels ,
34+ height ,
35+ width ,
36+ )
4337 self .inputs_mesh_mapper , self .weights_mesh_mapper , self .output_mesh_composer = self .get_mesh_mappers (device )
4438
45- # Initialize torch model
46- torch_model = TorchPanopticDeepLab ()
47- torch_model = load_torch_model_state (torch_model , "panoptic_deeplab" )
48-
49- # Create input tensor
39+ # Torch reference model + inputs
40+ torch_model = TorchPanopticDeepLab ().eval ()
5041 input_shape = (batch_size * self .num_devices , in_channels , height , width )
5142 self .torch_input_tensor = torch .rand (input_shape , dtype = torch .float )
5243
53- # Preprocess model parameters
44+ # Preprocess TTNN parameters
5445 parameters = preprocess_model_parameters (
5546 initialize_model = lambda : torch_model ,
5647 custom_preprocessor = create_custom_mesh_preprocessor (self .weights_mesh_mapper ),
5748 device = None ,
5849 )
5950
60- parameters .conv_args = {}
61- input_tensor = torch .randn (1 , 2048 , 32 , 64 )
62- res3_tensor = torch .randn (1 , 512 , 64 , 128 )
63- res2_tensor = torch .randn (1 , 256 , 128 , 256 )
51+ # Populate conv_args for decoders via one small warm-up pass
52+ self ._populate_all_decoders (torch_model , parameters )
6453
65- # For semantic decoder
66- if hasattr (parameters , "semantic_decoder" ):
67- # ASPP
68- aspp_args = infer_ttnn_module_args (
69- model = torch_model .semantic_decoder .aspp , run_model = lambda model : model (input_tensor ), device = None
70- )
71- if hasattr (parameters .semantic_decoder , "aspp" ):
72- parameters .semantic_decoder .aspp .conv_args = aspp_args
73-
74- # Res3
75- aspp_out = torch_model .semantic_decoder .aspp (input_tensor )
76- res3_args = infer_ttnn_module_args (
77- model = torch_model .semantic_decoder .res3 ,
78- run_model = lambda model : model (aspp_out , res3_tensor ),
79- device = None ,
80- )
81- if hasattr (parameters .semantic_decoder , "res3" ):
82- parameters .semantic_decoder .res3 .conv_args = res3_args
83-
84- # Res2
85- res3_out = torch_model .semantic_decoder .res3 (aspp_out , res3_tensor )
86- res2_args = infer_ttnn_module_args (
87- model = torch_model .semantic_decoder .res2 ,
88- run_model = lambda model : model (res3_out , res2_tensor ),
89- device = None ,
90- )
91- if hasattr (parameters .semantic_decoder , "res2" ):
92- parameters .semantic_decoder .res2 .conv_args = res2_args
93-
94- # Head
95- res2_out = torch_model .semantic_decoder .res2 (res3_out , res2_tensor )
96- head_args = infer_ttnn_module_args (
97- model = torch_model .semantic_decoder .head_1 , run_model = lambda model : model (res2_out ), device = None
98- )
99- if hasattr (parameters .semantic_decoder , "head_1" ):
100- parameters .semantic_decoder .head_1 .conv_args = head_args
101-
102- # For instance decoder
103- if hasattr (parameters , "instance_decoder" ):
104- # ASPP
105- aspp_args = infer_ttnn_module_args (
106- model = torch_model .instance_decoder .aspp , run_model = lambda model : model (input_tensor ), device = None
107- )
108- if hasattr (parameters .instance_decoder , "aspp" ):
109- parameters .instance_decoder .aspp .conv_args = aspp_args
110-
111- # Res3
112- aspp_out = torch_model .instance_decoder .aspp (input_tensor )
113- res3_args = infer_ttnn_module_args (
114- model = torch_model .instance_decoder .res3 ,
115- run_model = lambda model : model (aspp_out , res3_tensor ),
116- device = None ,
117- )
118- if hasattr (parameters .instance_decoder , "res3" ):
119- parameters .instance_decoder .res3 .conv_args = res3_args
120-
121- # Res2
122- res3_out = torch_model .instance_decoder .res3 (aspp_out , res3_tensor )
123- res2_args = infer_ttnn_module_args (
124- model = torch_model .instance_decoder .res2 ,
125- run_model = lambda model : model (res3_out , res2_tensor ),
126- device = None ,
127- )
128- if hasattr (parameters .instance_decoder , "res2" ):
129- parameters .instance_decoder .res2 .conv_args = res2_args
130-
131- # Head
132- res2_out = torch_model .instance_decoder .res2 (res3_out , res2_tensor )
133- head_args_1 = infer_ttnn_module_args (
134- model = torch_model .instance_decoder .head_1 , run_model = lambda model : model (res2_out ), device = None
135- )
136- head_args_2 = infer_ttnn_module_args (
137- model = torch_model .instance_decoder .head_2 , run_model = lambda model : model (res2_out ), device = None
138- )
139- if hasattr (parameters .instance_decoder , "head_1" ):
140- parameters .instance_decoder .head_1 .conv_args = head_args_1
141- if hasattr (parameters .instance_decoder , "head_2" ):
142- parameters .instance_decoder .head_2 .conv_args = head_args_2
143-
144- # Run torch model with bfloat16
54+ # Run Torch once (fp32) → then bf16 for parity with TTNN
14555 logger .info ("Running PyTorch model..." )
14656 self .torch_output_tensor , self .torch_output_tensor_2 , self .torch_output_tensor_3 = torch_model (
14757 self .torch_input_tensor
14858 )
14959
150- # Convert input to TTNN format ( NHWC)
60+ # Convert input to TTNN NHWC host tensor
15161 logger .info ("Converting input to TTNN format..." )
15262 tt_host_tensor = ttnn .from_torch (
15363 self .torch_input_tensor .permute (0 , 2 , 3 , 1 ),
15464 dtype = ttnn .bfloat16 ,
15565 mesh_mapper = self .inputs_mesh_mapper ,
15666 )
15767
158- # Initialize TTNN model
68+ # TTNN model
15969 logger .info ("Initializing TTNN model..." )
160- print ("Initializing TTNN model..." )
161- self .ttnn_model = TTPanopticDeepLab (
162- parameters = parameters ,
163- model_config = model_config ,
164- )
70+ self .ttnn_model = TTPanopticDeepLab (parameters = parameters , model_config = model_config )
16571
166- logger .info ("Running first TTNN model pass (JIT configuration)..." )
167- # first run configures convs JIT
168- self .input_tensor = ttnn .to_device (tt_host_tensor , device )
169- self .run ()
170- self .validate ()
72+ # First run configures JIT, second run is optimized
73+ for phase in ("JIT configuration" , "optimized" ):
74+ logger .info (f"Running TTNN model pass ({ phase } )..." )
75+ self .input_tensor = ttnn .to_device (tt_host_tensor , device )
76+ self .run ()
77+ self .validate ()
17178
172- logger .info ("Running optimized TTNN model pass..." )
173- # Optimized run
174- self .input_tensor = ttnn .to_device (tt_host_tensor , device )
175- self .run ()
176- self .validate ()
79+ # --------------------------- Setup & helpers ---------------------------
80+
81+ @classmethod
82+ def _maybe_seed (cls ):
83+ if not cls ._seeded :
84+ torch .manual_seed (42 )
85+ torch .cuda .manual_seed_all (42 )
86+ torch .backends .cudnn .deterministic = True
87+ cls ._seeded = True
17788
17889 def get_mesh_mappers (self , device ):
17990 if device .get_num_devices () != 1 :
180- inputs_mesh_mapper = ttnn .ShardTensorToMesh (device , dim = 0 )
181- weights_mesh_mapper = None
182- output_mesh_composer = ttnn .ConcatMeshToTensor (device , dim = 0 )
183- else :
184- inputs_mesh_mapper = None
185- weights_mesh_mapper = None
186- output_mesh_composer = None
187- return inputs_mesh_mapper , weights_mesh_mapper , output_mesh_composer
91+ return (
92+ ttnn .ShardTensorToMesh (device , dim = 0 ), # inputs
93+ None , # weights
94+ ttnn .ConcatMeshToTensor (device , dim = 0 ), # outputs
95+ )
96+ return None , None , None
18897
189- def run (self ):
190- self .output_tensor , self .output_tensor_2 , self .output_tensor_3 = self .ttnn_model (self .input_tensor , self .device )
191- return self .output_tensor , self .output_tensor_2 , self .output_tensor_3
98+ @staticmethod
99+ def _infer_and_set (module , params_holder , attr_name , run_fn ):
100+ """Infer conv args for a TTNN module and set them if present in parameters."""
101+ if hasattr (params_holder , attr_name ):
102+ args = infer_ttnn_module_args (model = module , run_model = run_fn , device = None )
103+ getattr (params_holder , attr_name ).conv_args = args
192104
193- def validate (self , output_tensor = None ):
194- output_tensor = self .output_tensor if output_tensor is None else output_tensor
195- output_tensor = ttnn .to_torch (output_tensor , device = self .device , mesh_composer = self .output_mesh_composer )
196- expected_shape = self .torch_output_tensor .shape
197- output_tensor = torch .reshape (
198- output_tensor , (expected_shape [0 ], expected_shape [2 ], expected_shape [3 ], expected_shape [1 ])
199- )
200- output_tensor = torch .permute (output_tensor , (0 , 3 , 1 , 2 ))
105+ def _populate_decoder (self , torch_dec , params_dec ):
106+ """Warm up a single decoder (semantic or instance) to populate conv_args."""
107+ if not (torch_dec and params_dec ):
108+ return
201109
202- batch_size = output_tensor .shape [0 ]
110+ # Synthetic tensors that match typical Panoptic-DeepLab strides
111+ input_tensor = torch .randn (1 , 2048 , 32 , 64 )
112+ res3_tensor = torch .randn (1 , 512 , 64 , 128 )
113+ res2_tensor = torch .randn (1 , 256 , 128 , 256 )
203114
204- valid_pcc = 0.97
205- self .pcc_passed , self .pcc_message = check_with_pcc (self .torch_output_tensor , output_tensor , pcc = valid_pcc )
206- assert self .pcc_passed , logger .error (f"Semantic Segmentation Head PCC check failed: { self .pcc_message } " )
207- logger .info (
208- f"Panoptic DeepLab - Semantic Segmentation Head: batch_size={ self .batch_size } , "
209- f"act_dtype={ model_config ['ACTIVATIONS_DTYPE' ]} , weight_dtype={ model_config ['WEIGHTS_DTYPE' ]} , "
210- f"math_fidelity={ model_config ['MATH_FIDELITY' ]} , PCC={ self .pcc_message } , shape={ self .output_tensor .shape } "
211- )
115+ # ASPP
116+ self ._infer_and_set (torch_dec .aspp , params_dec , "aspp" , lambda m : m (input_tensor ))
117+ aspp_out = torch_dec .aspp (input_tensor )
212118
213- # Validate instance segmentation head outputs
214- output_tensor = self .output_tensor_2
215- output_tensor = ttnn .to_torch (output_tensor , device = self .device , mesh_composer = self .output_mesh_composer )
216- expected_shape = self .torch_output_tensor_2 .shape
217- output_tensor = torch .reshape (
218- output_tensor , (expected_shape [0 ], expected_shape [2 ], expected_shape [3 ], expected_shape [1 ])
219- )
220- output_tensor = torch .permute (output_tensor , (0 , 3 , 1 , 2 ))
119+ # res3
120+ self ._infer_and_set (torch_dec .res3 , params_dec , "res3" , lambda m : m (aspp_out , res3_tensor ))
121+ res3_out = torch_dec .res3 (aspp_out , res3_tensor )
221122
222- batch_size = output_tensor .shape [0 ]
123+ # res2
124+ self ._infer_and_set (torch_dec .res2 , params_dec , "res2" , lambda m : m (res3_out , res2_tensor ))
125+ res2_out = torch_dec .res2 (res3_out , res2_tensor )
223126
224- valid_pcc = 0.97
225- self .pcc_passed , self .pcc_message = check_with_pcc (self .torch_output_tensor_2 , output_tensor , pcc = valid_pcc )
226- assert self .pcc_passed , logger .error (f"Instance Segmentation Head PCC check failed: { self .pcc_message } " )
227- logger .info (
228- f"Panoptic DeepLab - Instance Segmentation Offset Head: batch_size={ self .batch_size } , "
229- f"act_dtype={ model_config ['ACTIVATIONS_DTYPE' ]} , weight_dtype={ model_config ['WEIGHTS_DTYPE' ]} , "
230- f"math_fidelity={ model_config ['MATH_FIDELITY' ]} , PCC={ self .pcc_message } , shape={ self .output_tensor_2 .shape } "
231- )
127+ # heads (one or two, if present)
128+ if hasattr (torch_dec , "head_1" ):
129+ self ._infer_and_set (torch_dec .head_1 , params_dec , "head_1" , lambda m : m (res2_out ))
130+ if hasattr (torch_dec , "head_2" ):
131+ self ._infer_and_set (torch_dec .head_2 , params_dec , "head_2" , lambda m : m (res2_out ))
232132
233- output_tensor = self .output_tensor_3
234- output_tensor = ttnn .to_torch (output_tensor , device = self .device , mesh_composer = self .output_mesh_composer )
235- expected_shape = self .torch_output_tensor_3 .shape
236- output_tensor = torch .reshape (
237- output_tensor , (expected_shape [0 ], expected_shape [2 ], expected_shape [3 ], expected_shape [1 ])
238- )
239- output_tensor = torch .permute (output_tensor , (0 , 3 , 1 , 2 ))
133+ def _populate_all_decoders (self , torch_model , parameters ):
134+ if hasattr (parameters , "semantic_decoder" ):
135+ self ._populate_decoder (torch_model .semantic_decoder , parameters .semantic_decoder )
136+ if hasattr (parameters , "instance_decoder" ):
137+ self ._populate_decoder (torch_model .instance_decoder , parameters .instance_decoder )
240138
241- batch_size = output_tensor .shape [0 ]
139+ @staticmethod
140+ def _tt_to_torch_nchw (tt_tensor , device , mesh_composer , expected_shape ):
141+ """Convert TTNN NHWC tensor back to Torch NCHW and reshape to expected batch/shape."""
142+ t = ttnn .to_torch (tt_tensor , device = device , mesh_composer = mesh_composer )
143+ t = torch .reshape (t , (expected_shape [0 ], expected_shape [2 ], expected_shape [3 ], expected_shape [1 ]))
144+ return torch .permute (t , (0 , 3 , 1 , 2 ))
242145
243- valid_pcc = 0.97
244- self .pcc_passed , self .pcc_message = check_with_pcc (self .torch_output_tensor_3 , output_tensor , pcc = valid_pcc )
245- assert self .pcc_passed , logger .error (f"Instance Segmentation Head 2 PCC check failed: { self .pcc_message } " )
246- logger .info (
247- f"Panoptic DeepLab - Instance Segmentation Center Head: batch_size={ self .batch_size } , "
248- f"act_dtype={ model_config ['ACTIVATIONS_DTYPE' ]} , weight_dtype={ model_config ['WEIGHTS_DTYPE' ]} , "
249- f"math_fidelity={ model_config ['MATH_FIDELITY' ]} , PCC={ self .pcc_message } , shape={ self .output_tensor_3 .shape } "
250- )
146+ # --------------------------- Core runs/validation ---------------------------
251147
252- return self .pcc_passed , self .pcc_message
148+ def run (self ):
149+ self .output_tensor , self .output_tensor_2 , self .output_tensor_3 = self .ttnn_model (self .input_tensor , self .device )
150+ return self .output_tensor , self .output_tensor_2 , self .output_tensor_3
151+
152+ def validate (self ):
153+ """Validate three heads (semantic, offsets, centers) in a uniform loop."""
154+ checks = [
155+ ("Semantic Segmentation Head" , self .output_tensor , self .torch_output_tensor ),
156+ ("Instance Segmentation Offset Head" , self .output_tensor_2 , self .torch_output_tensor_2 ),
157+ ("Instance Segmentation Center Head" , self .output_tensor_3 , self .torch_output_tensor_3 ),
158+ ]
159+
160+ for name , tt_out , torch_ref in checks :
161+ out = self ._tt_to_torch_nchw (tt_out , self .device , self .output_mesh_composer , torch_ref .shape )
162+ passed , msg = check_with_pcc (torch_ref , out , pcc = self ._PCC_THRESH )
163+ assert passed , logger .error (f"{ name } PCC check failed: { msg } " )
164+
165+ logger .info (
166+ f"Panoptic DeepLab - { name } : batch_size={ self .batch_size } , "
167+ f"act_dtype={ self .model_config ['ACTIVATIONS_DTYPE' ]} , "
168+ f"weight_dtype={ self .model_config ['WEIGHTS_DTYPE' ]} , "
169+ f"math_fidelity={ self .model_config ['MATH_FIDELITY' ]} , "
170+ f"PCC={ msg } , shape={ tt_out .shape } "
171+ )
253172
173+ return True , f"All heads passed PCC ≥ { self ._PCC_THRESH } "
174+
175+
176+ # --------------------------- Test config ---------------------------
254177
255178model_config = {
256179 "MATH_FIDELITY" : ttnn .MathFidelity .LoFi ,
@@ -260,24 +183,6 @@ def validate(self, output_tensor=None):
260183
261184
262185@pytest .mark .parametrize ("device_params" , [{"l1_small_size" : 16384 }], indirect = True )
263- @pytest .mark .parametrize (
264- "batch_size, in_channels, height, width" ,
265- [
266- (1 , 3 , 512 , 1024 ),
267- ],
268- )
269- def test_panoptic_deeplab (
270- device ,
271- batch_size ,
272- in_channels ,
273- height ,
274- width ,
275- ):
276- PanopticDeepLabTestInfra (
277- device ,
278- batch_size ,
279- in_channels ,
280- height ,
281- width ,
282- model_config ,
283- )
186+ @pytest .mark .parametrize ("batch_size, in_channels, height, width" , [(1 , 3 , 512 , 1024 )])
187+ def test_panoptic_deeplab (device , batch_size , in_channels , height , width ):
188+ PanopticDeepLabTestInfra (device , batch_size , in_channels , height , width , model_config )
0 commit comments