@@ -8,6 +8,11 @@
 import pickle
 import numpy as np
 import os
+from PIL import Image
+from typing import Tuple
+import torchvision.transforms as transforms
+from typing import Optional, Any
+import ttnn
 from models.experimental.panoptic_deeplab.reference.resnet52_backbone import ResNet52BackBone as TorchBackbone
 from models.experimental.panoptic_deeplab.reference.resnet52_stem import DeepLabStem
 from torchvision.models.resnet import Bottleneck
@@ -190,12 +195,12 @@ def load_torch_model_state(torch_model: torch.nn.Module = None, layer_name: str
     model_path = model_location_generator("vision-models/panoptic_deeplab", model_subdir="", download_if_ci_v2=True)
     if model_path == "models":
         if not os.path.exists(
-            "models/experimental/panoptic_deeplab/reference/Panoptic_Deeplab_R52.pkl"
+            "models/experimental/panoptic_deeplab/resources/Panoptic_Deeplab_R52.pkl"
         ):  # check if Panoptic_Deeplab_R52.pkl is available
             os.system(
-                "models/experimental/panoptic_deeplab/reference/panoptic_deeplab_weights_download.sh"
+                "models/experimental/panoptic_deeplab/resources/panoptic_deeplab_weights_download.sh"
             )  # execute the panoptic_deeplab_weights_download.sh file
-        weights_path = "models/experimental/panoptic_deeplab/reference/Panoptic_Deeplab_R52.pkl"
+        weights_path = "models/experimental/panoptic_deeplab/resources/Panoptic_Deeplab_R52.pkl"
     else:
         weights_path = os.path.join(model_path, "Panoptic_Deeplab_R52.pkl")

@@ -209,7 +214,6 @@ def load_torch_model_state(torch_model: torch.nn.Module = None, layer_name: str
         if isinstance(v, np.ndarray):
             state_dict[k] = torch.from_numpy(v)
             converted_count += 1
-    logger.debug(f"Converted {converted_count} numpy arrays to torch tensors")

     # Get keys
     checkpoint_keys = set(state_dict.keys())
@@ -225,6 +229,10 @@ def load_torch_model_state(torch_model: torch.nn.Module = None, layer_name: str
     mapped_state_dict = {}
     for checkpoint_key, model_key in key_mapping.items():
         mapped_state_dict[model_key] = state_dict[checkpoint_key]
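+    # pixel_mean / pixel_std are input-normalization buffers, not module weights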
+    del mapped_state_dict["pixel_mean"]
+    del mapped_state_dict["pixel_std"]
+    logger.debug(f"Mapped {len(mapped_state_dict)} weights")

     if isinstance(
         torch_model,
@@ -240,10 +247,171 @@ def load_torch_model_state(torch_model: torch.nn.Module = None, layer_name: str
     ):
         torch_model = load_partial_state(torch_model, mapped_state_dict, layer_name)
     elif isinstance(torch_model, TorchPanopticDeepLab):
-        del mapped_state_dict["pixel_mean"]
-        del mapped_state_dict["pixel_std"]
         torch_model.load_state_dict(mapped_state_dict, strict=True)
     else:
         raise NotImplementedError("Unknown torch model. Weight loading not implemented")

     return torch_model.eval()
+
+
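+# Example usage (illustrative; constructor arguments elided):
+#   model = load_torch_model_state(TorchPanopticDeepLab(), ...)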
+def parameter_conv_args(torch_model: torch.nn.Module = None, parameters: dict = None):
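+    """Populate conv_args for each decoder module by tracing sample tensors through it."""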
+    from ttnn.model_preprocessing import infer_ttnn_module_args
+
+    if isinstance(torch_model, TorchPanopticDeepLab):
+        parameters.conv_args = {}
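+        # Sample shapes match the backbone's res5/res3/res2 outputs (strides 32/16/8), presumably for a 1024x2048 input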
+        sample_x = torch.randn(1, 2048, 32, 64)
+        sample_res3 = torch.randn(1, 512, 64, 128)
+        sample_res2 = torch.randn(1, 256, 128, 256)
+
+        # For semantic decoder
+        if hasattr(parameters, "semantic_decoder"):
+            # ASPP
+            aspp_args = infer_ttnn_module_args(
+                model=torch_model.semantic_decoder.aspp, run_model=lambda model: model(sample_x), device=None
+            )
+            if hasattr(parameters.semantic_decoder, "aspp"):
+                parameters.semantic_decoder.aspp.conv_args = aspp_args
+
+            # Res3
+            aspp_out = torch_model.semantic_decoder.aspp(sample_x)
+            res3_args = infer_ttnn_module_args(
+                model=torch_model.semantic_decoder.res3,
+                run_model=lambda model: model(aspp_out, sample_res3),
+                device=None,
+            )
+            if hasattr(parameters.semantic_decoder, "res3"):
+                parameters.semantic_decoder.res3.conv_args = res3_args
+
+            # Res2
+            res3_out = torch_model.semantic_decoder.res3(aspp_out, sample_res3)
+            res2_args = infer_ttnn_module_args(
+                model=torch_model.semantic_decoder.res2,
+                run_model=lambda model: model(res3_out, sample_res2),
+                device=None,
+            )
+            if hasattr(parameters.semantic_decoder, "res2"):
+                parameters.semantic_decoder.res2.conv_args = res2_args
+
+            # Head
+            res2_out = torch_model.semantic_decoder.res2(res3_out, sample_res2)
+            head_args = infer_ttnn_module_args(
+                model=torch_model.semantic_decoder.head_1, run_model=lambda model: model(res2_out), device=None
+            )
+            if hasattr(parameters.semantic_decoder, "head_1"):
+                parameters.semantic_decoder.head_1.conv_args = head_args
+
+        # For instance decoder
+        if hasattr(parameters, "instance_decoder"):
+            # ASPP
+            aspp_args = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.aspp, run_model=lambda model: model(sample_x), device=None
+            )
+            if hasattr(parameters.instance_decoder, "aspp"):
+                parameters.instance_decoder.aspp.conv_args = aspp_args
+
+            # Res3
+            aspp_out = torch_model.instance_decoder.aspp(sample_x)
+            res3_args = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.res3,
+                run_model=lambda model: model(aspp_out, sample_res3),
+                device=None,
+            )
+            if hasattr(parameters.instance_decoder, "res3"):
+                parameters.instance_decoder.res3.conv_args = res3_args
+
+            # Res2
+            res3_out = torch_model.instance_decoder.res3(aspp_out, sample_res3)
+            res2_args = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.res2,
+                run_model=lambda model: model(res3_out, sample_res2),
+                device=None,
+            )
+            if hasattr(parameters.instance_decoder, "res2"):
+                parameters.instance_decoder.res2.conv_args = res2_args
+
+            # Heads
+            res2_out = torch_model.instance_decoder.res2(res3_out, sample_res2)
+            head_args_1 = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.head_1, run_model=lambda model: model(res2_out), device=None
+            )
+            head_args_2 = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.head_2, run_model=lambda model: model(res2_out), device=None
+            )
+            if hasattr(parameters.instance_decoder, "head_1"):
+                parameters.instance_decoder.head_1.conv_args = head_args_1
+            if hasattr(parameters.instance_decoder, "head_2"):
+                parameters.instance_decoder.head_2.conv_args = head_args_2
+    else:
+        raise NotImplementedError("Unknown torch model. Parameter conv args not implemented")
+    return parameters
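+# Example usage (illustrative; `parameters` as produced by
+# ttnn.model_preprocessing.preprocess_model_parameters):
+#   parameters = parameter_conv_args(torch_model, parameters)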
+
+
+def preprocess_image(
+    image_path: str, input_width: int, input_height: int, ttnn_device: ttnn.Device, inputs_mesh_mapper: Optional[Any]
+) -> Tuple[torch.Tensor, ttnn.Tensor, np.ndarray, Tuple[int, int]]:
+    """Preprocess an image for both PyTorch and TTNN."""
+    # Load image
+    image = Image.open(image_path).convert("RGB")
+    original_size = image.size  # (width, height)
+    original_array = np.array(image)
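+    # Standard ImageNet mean/std normalization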
+    preprocess = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
+    )
+
+    # Resize to model input size
+    target_size = (input_width, input_height)  # PIL expects (width, height)
+    image_resized = image.resize(target_size)
+
+    # PyTorch preprocessing
+    torch_tensor = preprocess(image_resized).unsqueeze(0)  # Add batch dimension
+    torch_tensor = torch_tensor.to(torch.float)
+
+    # TTNN preprocessing
+    ttnn_tensor = ttnn.from_torch(
+        torch_tensor.permute(0, 2, 3, 1),  # BCHW -> BHWC
+        dtype=ttnn.bfloat16,
+        device=ttnn_device,
+        mesh_mapper=inputs_mesh_mapper,
+    )
+
+    return torch_tensor, ttnn_tensor, original_array, original_size
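+# Example usage (illustrative; assumes an open ttnn device):
+#   torch_in, ttnn_in, orig_arr, orig_size = preprocess_image("frame.png", 2048, 1024, device, None)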
+
+
+def save_preprocessed_inputs(torch_input: torch.Tensor, save_dir: str, filename: str):
+    """Save preprocessed inputs for testing purposes."""
+
+    # Create directory for test inputs
+    test_inputs_dir = os.path.join(save_dir, "test_inputs")
+    os.makedirs(test_inputs_dir, exist_ok=True)
+
+    # Save torch input tensor
+    torch_input_path = os.path.join(test_inputs_dir, f"{filename}_torch_input.pt")
+    torch.save(
+        {
+            "tensor": torch_input,
+            "shape": torch_input.shape,
+            "dtype": torch_input.dtype,
+            "mean": torch_input.mean().item(),
+            "std": torch_input.std().item(),
+            "min": torch_input.min().item(),
+            "max": torch_input.max().item(),
+        },
+        torch_input_path,
+    )
+
+    logger.info(f"Saved preprocessed torch input to: {torch_input_path}")
+
+    return torch_input_path
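+# Example usage (illustrative):
+#   path = save_preprocessed_inputs(torch_in, "generated/panoptic_deeplab", "sample")
+#   blob = torch.load(path)  # dict holding "tensor" plus summary stats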