Skip to content

Commit fcdc2a0

Browse files
committed
clean full net test
1 parent fb87485 commit fcdc2a0

File tree

3 files changed

+65
-53
lines changed

3 files changed

+65
-53
lines changed

models/experimental/panoptic_deeplab/reference/panoptic_deeplab.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@ def __init__(
2020
) -> None:
2121
super().__init__()
2222

23-
# self.pixel_std = nn.Parameter(torch.randn((3, 1, 1)))
24-
# self.pixel_mean = nn.Parameter(torch.randn((3, 1, 1)))
25-
# self.register_buffer("pixel_mean", torch.randn(3).view(-1, 1, 1), False)
26-
# self.register_buffer("pixel_std", torch.randn(3).view(-1, 1, 1), False)
27-
# self.register_buffer("adsaf", torch.randn(3).view(-1, 1, 1), False)
28-
# self.register_buffer("yurfdgdf", torch.randn(3).view(-1, 1, 1), False)
29-
3023
# Backbone
3124
self.backbone = ResNet52BackBone()
3225

models/experimental/panoptic_deeplab/tests/test_panoptic_deeplab.py

Lines changed: 63 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,52 @@
55
import torch
66
from loguru import logger
77
import ttnn
8-
from ttnn.model_preprocessing import (
9-
preprocess_model_parameters,
10-
infer_ttnn_module_args,
11-
)
8+
from ttnn.model_preprocessing import preprocess_model_parameters
129
from tests.ttnn.utils_for_testing import check_with_pcc
1310

1411
from models.experimental.panoptic_deeplab.reference.panoptic_deeplab import TorchPanopticDeepLab
1512
from models.experimental.panoptic_deeplab.tt.panoptic_deeplab import TTPanopticDeepLab
1613
from models.experimental.panoptic_deeplab.tt.custom_preprocessing import create_custom_mesh_preprocessor
14+
from ttnn.model_preprocessing import infer_ttnn_module_args, preprocess_model_parameters
15+
from models.experimental.panoptic_deeplab.common import load_torch_model_state
1716

1817

1918
class PanopticDeepLabTestInfra:
20-
_seeded = False
21-
_PCC_THRESH = 0.97
22-
23-
def __init__(self, device, batch_size, in_channels, height, width, model_config):
19+
def __init__(
20+
self,
21+
device,
22+
batch_size,
23+
in_channels,
24+
height,
25+
width,
26+
model_config,
27+
):
2428
super().__init__()
25-
self._maybe_seed()
29+
if not hasattr(self, "_model_initialized"):
30+
torch.manual_seed(42)
31+
self._model_initialized = True
32+
torch.cuda.manual_seed_all(42)
33+
torch.backends.cudnn.deterministic = True
2634

27-
# Core state
35+
self.pcc_passed = False
36+
self.pcc_message = "call validate()?"
2837
self.device = device
29-
self.model_config = model_config
3038
self.num_devices = device.get_num_devices()
31-
self.batch_size, self.in_channels, self.height, self.width = (
32-
batch_size,
33-
in_channels,
34-
height,
35-
width,
36-
)
39+
self.batch_size = batch_size
40+
self.in_channels = in_channels
41+
self.height = height
42+
self.width = width
3743
self.inputs_mesh_mapper, self.weights_mesh_mapper, self.output_mesh_composer = self.get_mesh_mappers(device)
3844

39-
# Torch reference model + inputs
40-
torch_model = TorchPanopticDeepLab().eval()
45+
# Initialize torch model
46+
torch_model = TorchPanopticDeepLab()
47+
torch_model = load_torch_model_state(torch_model, "panoptic_deeplab")
48+
49+
# Create input tensor
4150
input_shape = (batch_size * self.num_devices, in_channels, height, width)
4251
self.torch_input_tensor = torch.rand(input_shape, dtype=torch.float)
4352

44-
# Preprocess TTNN parameters
53+
# Preprocess model parameters
4554
parameters = preprocess_model_parameters(
4655
initialize_model=lambda: torch_model,
4756
custom_preprocessor=create_custom_mesh_preprocessor(self.weights_mesh_mapper),
@@ -51,23 +60,27 @@ def __init__(self, device, batch_size, in_channels, height, width, model_config)
5160
# Populate conv_args for decoders via one small warm-up pass
5261
self._populate_all_decoders(torch_model, parameters)
5362

54-
# Run Torch once (fp32) → then bf16 for parity with TTNN
63+
# Run torch model with bfloat16
5564
logger.info("Running PyTorch model...")
5665
self.torch_output_tensor, self.torch_output_tensor_2, self.torch_output_tensor_3 = torch_model(
5766
self.torch_input_tensor
5867
)
5968

60-
# Convert input to TTNN NHWC host tensor
69+
# Convert input to TTNN format (NHWC)
6170
logger.info("Converting input to TTNN format...")
6271
tt_host_tensor = ttnn.from_torch(
6372
self.torch_input_tensor.permute(0, 2, 3, 1),
6473
dtype=ttnn.bfloat16,
6574
mesh_mapper=self.inputs_mesh_mapper,
6675
)
6776

68-
# TTNN model
77+
# Initialize TTNN model
6978
logger.info("Initializing TTNN model...")
70-
self.ttnn_model = TTPanopticDeepLab(parameters=parameters, model_config=model_config)
79+
print("Initializing TTNN model...")
80+
self.ttnn_model = TTPanopticDeepLab(
81+
parameters=parameters,
82+
model_config=model_config,
83+
)
7184

7285
# First run configures JIT, second run is optimized
7386
for phase in ("JIT configuration", "optimized"):
@@ -76,16 +89,7 @@ def __init__(self, device, batch_size, in_channels, height, width, model_config)
7689
self.run()
7790
self.validate()
7891

79-
# --------------------------- Setup & helpers ---------------------------
80-
8192
@classmethod
82-
def _maybe_seed(cls):
83-
if not cls._seeded:
84-
torch.manual_seed(42)
85-
torch.cuda.manual_seed_all(42)
86-
torch.backends.cudnn.deterministic = True
87-
cls._seeded = True
88-
8993
def get_mesh_mappers(self, device):
9094
if device.get_num_devices() != 1:
9195
return (
@@ -143,8 +147,6 @@ def _tt_to_torch_nchw(tt_tensor, device, mesh_composer, expected_shape):
143147
t = torch.reshape(t, (expected_shape[0], expected_shape[2], expected_shape[3], expected_shape[1]))
144148
return torch.permute(t, (0, 3, 1, 2))
145149

146-
# --------------------------- Core runs/validation ---------------------------
147-
148150
def run(self):
149151
self.output_tensor, self.output_tensor_2, self.output_tensor_3 = self.ttnn_model(self.input_tensor, self.device)
150152
return self.output_tensor, self.output_tensor_2, self.output_tensor_3
@@ -157,24 +159,24 @@ def validate(self):
157159
("Instance Segmentation Center Head", self.output_tensor_3, self.torch_output_tensor_3),
158160
]
159161

162+
self._PCC_THRESH = 0.97
163+
160164
for name, tt_out, torch_ref in checks:
161165
out = self._tt_to_torch_nchw(tt_out, self.device, self.output_mesh_composer, torch_ref.shape)
162166
passed, msg = check_with_pcc(torch_ref, out, pcc=self._PCC_THRESH)
163167
assert passed, logger.error(f"{name} PCC check failed: {msg}")
164168

165169
logger.info(
166170
f"Panoptic DeepLab - {name}: batch_size={self.batch_size}, "
167-
f"act_dtype={self.model_config['ACTIVATIONS_DTYPE']}, "
168-
f"weight_dtype={self.model_config['WEIGHTS_DTYPE']}, "
169-
f"math_fidelity={self.model_config['MATH_FIDELITY']}, "
171+
f"act_dtype={model_config['ACTIVATIONS_DTYPE']}, "
172+
f"weight_dtype={model_config['WEIGHTS_DTYPE']}, "
173+
f"math_fidelity={model_config['MATH_FIDELITY']}, "
170174
f"PCC={msg}, shape={tt_out.shape}"
171175
)
172176

173177
return True, f"All heads passed PCC ≥ {self._PCC_THRESH}"
174178

175179

176-
# --------------------------- Test config ---------------------------
177-
178180
model_config = {
179181
"MATH_FIDELITY": ttnn.MathFidelity.LoFi,
180182
"WEIGHTS_DTYPE": ttnn.bfloat8_b,
@@ -183,6 +185,24 @@ def validate(self):
183185

184186

185187
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
186-
@pytest.mark.parametrize("batch_size, in_channels, height, width", [(1, 3, 512, 1024)])
187-
def test_panoptic_deeplab(device, batch_size, in_channels, height, width):
188-
PanopticDeepLabTestInfra(device, batch_size, in_channels, height, width, model_config)
188+
@pytest.mark.parametrize(
189+
"batch_size, in_channels, height, width",
190+
[
191+
(1, 3, 512, 1024),
192+
],
193+
)
194+
def test_panoptic_deeplab(
195+
device,
196+
batch_size,
197+
in_channels,
198+
height,
199+
width,
200+
):
201+
PanopticDeepLabTestInfra(
202+
device,
203+
batch_size,
204+
in_channels,
205+
height,
206+
width,
207+
model_config,
208+
)

models/experimental/panoptic_deeplab/tt/panoptic_deeplab.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import ttnn
5-
from typing import Dict
65

76
from models.experimental.panoptic_deeplab.tt.backbone import TTBackbone
87
from models.experimental.panoptic_deeplab.tt.decoder import TTDecoder, decoder_layer_optimisations
@@ -42,9 +41,9 @@ def __init__(
4241

4342
def __call__(
4443
self,
45-
x: ttnn.Tensor,
44+
x,
4645
device,
47-
) -> Dict[str, ttnn.Tensor]:
46+
):
4847
"""
4948
Forward pass of TTNN Panoptic DeepLab.
5049

0 commit comments

Comments
 (0)