
Commit a2530b0
Demo added
1 parent 04e47b4

8 files changed: +1256, -9 lines

models/experimental/panoptic_deeplab/README.md

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@ pytest models/experimental/panoptic_deeplab/tests/test_panoptic_deeplab.py
 ```
 
 ### Demo
+```
+python models/experimental/panoptic_deeplab/demo/panoptic_deeplab_demo.py --input <input image path> --output <output image to be stored path>
+```
 **Note:** Output images will be saved in the `panoptic_deeplab_predictions/` folder.
 
 #### Single Device (BS=1):

models/experimental/panoptic_deeplab/common.py

Lines changed: 165 additions & 6 deletions
@@ -8,6 +8,11 @@
 import pickle
 import numpy as np
 import os
+from PIL import Image
+from typing import Tuple
+import torchvision.transforms as transforms
+from typing import Optional, Any
+import ttnn
 from models.experimental.panoptic_deeplab.reference.resnet52_backbone import ResNet52BackBone as TorchBackbone
 from models.experimental.panoptic_deeplab.reference.resnet52_stem import DeepLabStem
 from torchvision.models.resnet import Bottleneck
@@ -190,12 +195,12 @@ def load_torch_model_state(torch_model: torch.nn.Module = None, layer_name: str
     model_path = model_location_generator("vision-models/panoptic_deeplab", model_subdir="", download_if_ci_v2=True)
     if model_path == "models":
         if not os.path.exists(
-            "models/experimental/panoptic_deeplab/reference/Panoptic_Deeplab_R52.pkl"
+            "models/experimental/panoptic_deeplab/resources/Panoptic_Deeplab_R52.pkl"
         ):  # check if Panoptic_Deeplab_R52.pkl is available
             os.system(
-                "models/experimental/panoptic_deeplab/reference/panoptic_deeplab_weights_download.sh"
+                "models/experimental/panoptic_deeplab/resources/panoptic_deeplab_weights_download.sh"
             )  # execute the panoptic_deeplab_weights_download.sh file
-        weights_path = "models/experimental/panoptic_deeplab/reference/Panoptic_Deeplab_R52.pkl"
+        weights_path = "models/experimental/panoptic_deeplab/resources/Panoptic_Deeplab_R52.pkl"
     else:
         weights_path = os.path.join(model_path, "Panoptic_Deeplab_R52.pkl")

@@ -209,7 +214,6 @@ def load_torch_model_state(torch_model: torch.nn.Module = None, layer_name: str
         if isinstance(v, np.ndarray) or isinstance(v, np.array):
             state_dict[k] = torch.from_numpy(v)
             converted_count += 1
-    logger.debug(f"Converted {converted_count} numpy arrays to torch tensors")
 
     # Get keys
     checkpoint_keys = set(state_dict.keys())
@@ -225,6 +229,9 @@ def load_torch_model_state(torch_model: torch.nn.Module = None, layer_name: str
     mapped_state_dict = {}
    for checkpoint_key, model_key in key_mapping.items():
         mapped_state_dict[model_key] = state_dict[checkpoint_key]
+    del mapped_state_dict["pixel_mean"]
+    del mapped_state_dict["pixel_std"]
+    logger.debug(f"Mapped {len(mapped_state_dict)} weights")
 
     if isinstance(
         torch_model,
@@ -240,10 +247,162 @@ def load_torch_model_state(torch_model: torch.nn.Module = None, layer_name: str
     ):
         torch_model = load_partial_state(torch_model, mapped_state_dict, layer_name)
     elif isinstance(torch_model, TorchPanopticDeepLab):
-        del mapped_state_dict["pixel_mean"]
-        del mapped_state_dict["pixel_std"]
         torch_model.load_state_dict(mapped_state_dict, strict=True)
     else:
         raise NotImplementedError("Unknown torch model. Weight loading not implemented")
 
     return torch_model.eval()
+
+
+def parameter_conv_args(torch_model: torch.nn.Module = None, parameters: dict = None):
+    from ttnn.model_preprocessing import infer_ttnn_module_args
+
+    if isinstance(torch_model, TorchPanopticDeepLab):
+        parameters.conv_args = {}
+        sample_x = torch.randn(1, 2048, 32, 64)
+        sample_res3 = torch.randn(1, 512, 64, 128)
+        sample_res2 = torch.randn(1, 256, 128, 256)
+
+        # For semantic decoder
+        if hasattr(parameters, "semantic_decoder"):
+            # ASPP
+            aspp_args = infer_ttnn_module_args(
+                model=torch_model.semantic_decoder.aspp, run_model=lambda model: model(sample_x), device=None
+            )
+            if hasattr(parameters.semantic_decoder, "aspp"):
+                parameters.semantic_decoder.aspp.conv_args = aspp_args
+
+            # Res3
+            aspp_out = torch_model.semantic_decoder.aspp(sample_x)
+            res3_args = infer_ttnn_module_args(
+                model=torch_model.semantic_decoder.res3,
+                run_model=lambda model: model(aspp_out, sample_res3),
+                device=None,
+            )
+            if hasattr(parameters.semantic_decoder, "res3"):
+                parameters.semantic_decoder.res3.conv_args = res3_args
+
+            # Res2
+            res3_out = torch_model.semantic_decoder.res3(aspp_out, sample_res3)
+            res2_args = infer_ttnn_module_args(
+                model=torch_model.semantic_decoder.res2,
+                run_model=lambda model: model(res3_out, sample_res2),
+                device=None,
+            )
+            if hasattr(parameters.semantic_decoder, "res2"):
+                parameters.semantic_decoder.res2.conv_args = res2_args
+
+            # Head
+            res2_out = torch_model.semantic_decoder.res2(res3_out, sample_res2)
+            head_args = infer_ttnn_module_args(
+                model=torch_model.semantic_decoder.head_1, run_model=lambda model: model(res2_out), device=None
+            )
+            if hasattr(parameters.semantic_decoder, "head_1"):
+                parameters.semantic_decoder.head_1.conv_args = head_args
+
+        # For instance decoder
+        if hasattr(parameters, "instance_decoder"):
+            # ASPP
+            aspp_args = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.aspp, run_model=lambda model: model(sample_x), device=None
+            )
+            if hasattr(parameters.instance_decoder, "aspp"):
+                parameters.instance_decoder.aspp.conv_args = aspp_args
+
+            # Res3
+            aspp_out = torch_model.instance_decoder.aspp(sample_x)
+            res3_args = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.res3,
+                run_model=lambda model: model(aspp_out, sample_res3),
+                device=None,
+            )
+            if hasattr(parameters.instance_decoder, "res3"):
+                parameters.instance_decoder.res3.conv_args = res3_args
+
+            # Res2
+            res3_out = torch_model.instance_decoder.res3(aspp_out, sample_res3)
+            res2_args = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.res2,
+                run_model=lambda model: model(res3_out, sample_res2),
+                device=None,
+            )
+            if hasattr(parameters.instance_decoder, "res2"):
+                parameters.instance_decoder.res2.conv_args = res2_args
+
+            # Head
+            res2_out = torch_model.instance_decoder.res2(res3_out, sample_res2)
+            head_args_1 = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.head_1, run_model=lambda model: model(res2_out), device=None
+            )
+            head_args_2 = infer_ttnn_module_args(
+                model=torch_model.instance_decoder.head_2, run_model=lambda model: model(res2_out), device=None
+            )
+            if hasattr(parameters.instance_decoder, "head_1"):
+                parameters.instance_decoder.head_1.conv_args = head_args_1
+            if hasattr(parameters.instance_decoder, "head_2"):
+                parameters.instance_decoder.head_2.conv_args = head_args_2
+    else:
+        raise NotImplementedError("Unknown torch model. Parameter conv args not implemented")
+    return parameters
+
+
+def preprocess_image(
+    image_path: str, input_width: int, input_height: int, ttnn_device: ttnn.Device, inputs_mesh_mapper: Optional[Any]
+) -> Tuple[torch.Tensor, ttnn.Tensor, np.ndarray, Tuple[int, int]]:
+    """Preprocess image for both PyTorch and TTNN"""
+    # Load image
+    image = Image.open(image_path).convert("RGB")
+    original_size = image.size  # (width, height)
+    original_array = np.array(image)
+    preprocess = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
+    )
+
+    # Resize to model input size
+    target_size = (input_width, input_height)  # PIL expects (width, height)
+    image_resized = image.resize(target_size)
+
+    # PyTorch preprocessing
+    torch_tensor = preprocess(image_resized).unsqueeze(0)  # Add batch dimension
+    torch_tensor = torch_tensor.to(torch.float)
+
+    # TTNN preprocessing
+    ttnn_tensor = None
+    ttnn_tensor = ttnn.from_torch(
+        torch_tensor.permute(0, 2, 3, 1),  # BCHW -> BHWC
+        dtype=ttnn.bfloat16,
+        device=ttnn_device,
+        mesh_mapper=inputs_mesh_mapper,
+    )
+
+    if ttnn_tensor is not None:
+        ttnn_as_torch = ttnn.to_torch(ttnn_tensor)
+
+    return torch_tensor, ttnn_tensor, original_array, original_size
+
+
+def save_preprocessed_inputs(torch_input: torch.Tensor, save_dir: str, filename: str):
+    """Save preprocessed inputs for testing purposes"""
+
+    # Create directory for test inputs
+    test_inputs_dir = os.path.join(save_dir, "test_inputs")
+    os.makedirs(test_inputs_dir, exist_ok=True)
+
+    # Save torch input tensor
+    torch_input_path = os.path.join(test_inputs_dir, f"{filename}_torch_input.pt")
+    torch.save(
+        {
+            "tensor": torch_input,
+            "shape": torch_input.shape,
+            "dtype": torch_input.dtype,
+            "mean": torch_input.mean().item(),
+            "std": torch_input.std().item(),
+            "min": torch_input.min().item(),
+            "max": torch_input.max().item(),
+        },
+        torch_input_path,
+    )
+
+    logger.info(f"Saved preprocessed torch input to: {torch_input_path}")
+
+    return torch_input_path
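
For context, a minimal sketch of how the new helpers added above might be exercised from a demo script. This is not part of the commit: the image path, the single-device setup via `ttnn.open_device`, and the 1024x512 input size are illustrative assumptions.

```python
# Hypothetical usage sketch (not part of this commit) for the new common.py helpers.
import ttnn
from models.experimental.panoptic_deeplab.common import preprocess_image, save_preprocessed_inputs

device = ttnn.open_device(device_id=0)  # assumes a single local device

# preprocess_image returns the torch input, the ttnn input, the original image
# array, and the original (width, height) size.
torch_in, ttnn_in, original_array, original_size = preprocess_image(
    image_path="input.png",   # hypothetical input image
    input_width=1024,         # matches the DemoConfig defaults below
    input_height=512,
    ttnn_device=device,
    inputs_mesh_mapper=None,  # single device, so no mesh mapper
)

# Persist the preprocessed torch input for later test comparison.
save_preprocessed_inputs(torch_in, save_dir="panoptic_deeplab_predictions", filename="input")

ttnn.close_device(device)
```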
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import List, Optional
+import numpy as np
+
+
+@dataclass
+class DemoConfig:
+    """Configuration class for demo parameters"""
+
+    # Model configuration
+    model_type: str = "PanopticDeepLab"
+    backbone: str = "ResNet-52"
+    num_classes: int = 19
+    weights_path: Optional[str] = None
+
+    # Input configuration
+    input_height: int = 512
+    input_width: int = 1024
+    crop_enabled: bool = False
+    normalize_enabled: bool = True
+    mean: List[float] = None
+    std: List[float] = None
+
+    # Inference configuration
+    center_threshold: float = 0.1
+    nms_kernel: int = 7
+    top_k_instances: int = 200
+    stuff_area_threshold: int = 4096
+
+    # Device configuration
+    device_id: int = 0
+    math_fidelity: str = "LoFi"
+    weights_dtype: str = "bfloat8_b"
+    activations_dtype: str = "bfloat8_b"
+
+    # Output configuration
+    save_semantic: bool = True
+    save_instance: bool = True
+    save_panoptic: bool = True
+    save_visualization: bool = True
+    save_comparison: bool = True
+
+    # Pipeline configuration
+    compare_outputs: bool = True
+    pcc_threshold: float = 0.97
+
+    # Dataset configuration (Cityscapes default)
+    thing_classes: List[int] = None
+    stuff_classes: List[int] = None
+    class_names: List[str] = None
+
+    def __post_init__(self):
+        """Initialize default values after dataclass creation"""
+        if self.mean is None:
+            self.mean = [0.485, 0.456, 0.406]
+        if self.std is None:
+            self.std = [0.229, 0.224, 0.225]
+        if self.thing_classes is None:
+            self.thing_classes = [11, 12, 13, 14, 15, 16, 17, 18]  # Cityscapes things
+        if self.stuff_classes is None:
+            self.stuff_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # Cityscapes stuff
+        if self.class_names is None:
+            self.class_names = [
+                "road",
+                "sidewalk",
+                "building",
+                "wall",
+                "fence",
+                "pole",
+                "traffic_light",
+                "traffic_sign",
+                "vegetation",
+                "terrain",
+                "sky",
+                "person",
+                "rider",
+                "car",
+                "truck",
+                "bus",
+                "train",
+                "motorcycle",
+                "bicycle",
+            ]
+
+    def _get_cityscapes_colors(self) -> np.ndarray:
+        """Get Cityscapes color palette"""
+        return np.array(
+            [
+                [128, 64, 128],  # road
+                [244, 35, 232],  # sidewalk
+                [70, 70, 70],  # building
+                [102, 102, 156],  # wall
+                [190, 153, 153],  # fence
+                [153, 153, 153],  # pole
+                [250, 170, 30],  # traffic light
+                [220, 220, 0],  # traffic sign
+                [107, 142, 35],  # vegetation
+                [152, 251, 152],  # terrain
+                [70, 130, 180],  # sky
+                [220, 20, 60],  # person
+                [255, 0, 0],  # rider
+                [0, 0, 142],  # car
+                [0, 0, 70],  # truck
+                [0, 60, 100],  # bus
+                [0, 80, 100],  # train
+                [0, 0, 230],  # motorcycle
+                [119, 11, 32],  # bicycle
+            ]
+        )
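
A minimal sketch of using the new `DemoConfig` dataclass added above. The import path is an assumption (the commit view does not show this new file's name); the printed values simply echo the Cityscapes defaults that `__post_init__` fills in.

```python
# Hypothetical usage sketch (not part of this commit); the module path is assumed.
from models.experimental.panoptic_deeplab.demo.demo_config import DemoConfig  # assumed location

cfg = DemoConfig()                     # defaults: 1024x512 input, 19 Cityscapes classes
print(cfg.mean, cfg.std)               # [0.485, 0.456, 0.406] [0.229, 0.224, 0.225]
print(cfg.class_names[:3])             # ['road', 'sidewalk', 'building']
colors = cfg._get_cityscapes_colors()  # RGB palette, one row per class
print(colors.shape)                    # (19, 3)
```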
