Skip to content

Commit fb87485

Browse files
committed
refactored full net test
1 parent d283507 commit fb87485

File tree

1 file changed

+118
-213
lines changed

1 file changed

+118
-213
lines changed

models/experimental/panoptic_deeplab/tests/test_panoptic_deeplab.py

Lines changed: 118 additions & 213 deletions
Original file line numberDiff line numberDiff line change
@@ -5,252 +5,175 @@
55
import torch
66
from loguru import logger
77
import ttnn
8-
from ttnn.model_preprocessing import preprocess_model_parameters
8+
from ttnn.model_preprocessing import (
9+
preprocess_model_parameters,
10+
infer_ttnn_module_args,
11+
)
912
from tests.ttnn.utils_for_testing import check_with_pcc
1013

1114
from models.experimental.panoptic_deeplab.reference.panoptic_deeplab import TorchPanopticDeepLab
1215
from models.experimental.panoptic_deeplab.tt.panoptic_deeplab import TTPanopticDeepLab
1316
from models.experimental.panoptic_deeplab.tt.custom_preprocessing import create_custom_mesh_preprocessor
14-
from ttnn.model_preprocessing import infer_ttnn_module_args, preprocess_model_parameters
15-
from models.experimental.panoptic_deeplab.common import load_torch_model_state
1617

1718

1819
class PanopticDeepLabTestInfra:
19-
def __init__(
20-
self,
21-
device,
22-
batch_size,
23-
in_channels,
24-
height,
25-
width,
26-
model_config,
27-
):
20+
_seeded = False
21+
_PCC_THRESH = 0.97
22+
23+
def __init__(self, device, batch_size, in_channels, height, width, model_config):
2824
super().__init__()
29-
if not hasattr(self, "_model_initialized"):
30-
torch.manual_seed(42)
31-
self._model_initialized = True
32-
torch.cuda.manual_seed_all(42)
33-
torch.backends.cudnn.deterministic = True
25+
self._maybe_seed()
3426

35-
self.pcc_passed = False
36-
self.pcc_message = "call validate()?"
27+
# Core state
3728
self.device = device
29+
self.model_config = model_config
3830
self.num_devices = device.get_num_devices()
39-
self.batch_size = batch_size
40-
self.in_channels = in_channels
41-
self.height = height
42-
self.width = width
31+
self.batch_size, self.in_channels, self.height, self.width = (
32+
batch_size,
33+
in_channels,
34+
height,
35+
width,
36+
)
4337
self.inputs_mesh_mapper, self.weights_mesh_mapper, self.output_mesh_composer = self.get_mesh_mappers(device)
4438

45-
# Initialize torch model
46-
torch_model = TorchPanopticDeepLab()
47-
torch_model = load_torch_model_state(torch_model, "panoptic_deeplab")
48-
49-
# Create input tensor
39+
# Torch reference model + inputs
40+
torch_model = TorchPanopticDeepLab().eval()
5041
input_shape = (batch_size * self.num_devices, in_channels, height, width)
5142
self.torch_input_tensor = torch.rand(input_shape, dtype=torch.float)
5243

53-
# Preprocess model parameters
44+
# Preprocess TTNN parameters
5445
parameters = preprocess_model_parameters(
5546
initialize_model=lambda: torch_model,
5647
custom_preprocessor=create_custom_mesh_preprocessor(self.weights_mesh_mapper),
5748
device=None,
5849
)
5950

60-
parameters.conv_args = {}
61-
input_tensor = torch.randn(1, 2048, 32, 64)
62-
res3_tensor = torch.randn(1, 512, 64, 128)
63-
res2_tensor = torch.randn(1, 256, 128, 256)
51+
# Populate conv_args for decoders via one small warm-up pass
52+
self._populate_all_decoders(torch_model, parameters)
6453

65-
# For semantic decoder
66-
if hasattr(parameters, "semantic_decoder"):
67-
# ASPP
68-
aspp_args = infer_ttnn_module_args(
69-
model=torch_model.semantic_decoder.aspp, run_model=lambda model: model(input_tensor), device=None
70-
)
71-
if hasattr(parameters.semantic_decoder, "aspp"):
72-
parameters.semantic_decoder.aspp.conv_args = aspp_args
73-
74-
# Res3
75-
aspp_out = torch_model.semantic_decoder.aspp(input_tensor)
76-
res3_args = infer_ttnn_module_args(
77-
model=torch_model.semantic_decoder.res3,
78-
run_model=lambda model: model(aspp_out, res3_tensor),
79-
device=None,
80-
)
81-
if hasattr(parameters.semantic_decoder, "res3"):
82-
parameters.semantic_decoder.res3.conv_args = res3_args
83-
84-
# Res2
85-
res3_out = torch_model.semantic_decoder.res3(aspp_out, res3_tensor)
86-
res2_args = infer_ttnn_module_args(
87-
model=torch_model.semantic_decoder.res2,
88-
run_model=lambda model: model(res3_out, res2_tensor),
89-
device=None,
90-
)
91-
if hasattr(parameters.semantic_decoder, "res2"):
92-
parameters.semantic_decoder.res2.conv_args = res2_args
93-
94-
# Head
95-
res2_out = torch_model.semantic_decoder.res2(res3_out, res2_tensor)
96-
head_args = infer_ttnn_module_args(
97-
model=torch_model.semantic_decoder.head_1, run_model=lambda model: model(res2_out), device=None
98-
)
99-
if hasattr(parameters.semantic_decoder, "head_1"):
100-
parameters.semantic_decoder.head_1.conv_args = head_args
101-
102-
# For instance decoder
103-
if hasattr(parameters, "instance_decoder"):
104-
# ASPP
105-
aspp_args = infer_ttnn_module_args(
106-
model=torch_model.instance_decoder.aspp, run_model=lambda model: model(input_tensor), device=None
107-
)
108-
if hasattr(parameters.instance_decoder, "aspp"):
109-
parameters.instance_decoder.aspp.conv_args = aspp_args
110-
111-
# Res3
112-
aspp_out = torch_model.instance_decoder.aspp(input_tensor)
113-
res3_args = infer_ttnn_module_args(
114-
model=torch_model.instance_decoder.res3,
115-
run_model=lambda model: model(aspp_out, res3_tensor),
116-
device=None,
117-
)
118-
if hasattr(parameters.instance_decoder, "res3"):
119-
parameters.instance_decoder.res3.conv_args = res3_args
120-
121-
# Res2
122-
res3_out = torch_model.instance_decoder.res3(aspp_out, res3_tensor)
123-
res2_args = infer_ttnn_module_args(
124-
model=torch_model.instance_decoder.res2,
125-
run_model=lambda model: model(res3_out, res2_tensor),
126-
device=None,
127-
)
128-
if hasattr(parameters.instance_decoder, "res2"):
129-
parameters.instance_decoder.res2.conv_args = res2_args
130-
131-
# Head
132-
res2_out = torch_model.instance_decoder.res2(res3_out, res2_tensor)
133-
head_args_1 = infer_ttnn_module_args(
134-
model=torch_model.instance_decoder.head_1, run_model=lambda model: model(res2_out), device=None
135-
)
136-
head_args_2 = infer_ttnn_module_args(
137-
model=torch_model.instance_decoder.head_2, run_model=lambda model: model(res2_out), device=None
138-
)
139-
if hasattr(parameters.instance_decoder, "head_1"):
140-
parameters.instance_decoder.head_1.conv_args = head_args_1
141-
if hasattr(parameters.instance_decoder, "head_2"):
142-
parameters.instance_decoder.head_2.conv_args = head_args_2
143-
144-
# Run torch model with bfloat16
54+
# Run the Torch reference model once in fp32; its outputs are the baseline for the TTNN PCC checks
14555
logger.info("Running PyTorch model...")
14656
self.torch_output_tensor, self.torch_output_tensor_2, self.torch_output_tensor_3 = torch_model(
14757
self.torch_input_tensor
14858
)
14959

150-
# Convert input to TTNN format (NHWC)
60+
# Convert input to TTNN NHWC host tensor
15161
logger.info("Converting input to TTNN format...")
15262
tt_host_tensor = ttnn.from_torch(
15363
self.torch_input_tensor.permute(0, 2, 3, 1),
15464
dtype=ttnn.bfloat16,
15565
mesh_mapper=self.inputs_mesh_mapper,
15666
)
15767

158-
# Initialize TTNN model
68+
# TTNN model
15969
logger.info("Initializing TTNN model...")
160-
print("Initializing TTNN model...")
161-
self.ttnn_model = TTPanopticDeepLab(
162-
parameters=parameters,
163-
model_config=model_config,
164-
)
70+
self.ttnn_model = TTPanopticDeepLab(parameters=parameters, model_config=model_config)
16571

166-
logger.info("Running first TTNN model pass (JIT configuration)...")
167-
# first run configures convs JIT
168-
self.input_tensor = ttnn.to_device(tt_host_tensor, device)
169-
self.run()
170-
self.validate()
72+
# First run configures JIT, second run is optimized
73+
for phase in ("JIT configuration", "optimized"):
74+
logger.info(f"Running TTNN model pass ({phase})...")
75+
self.input_tensor = ttnn.to_device(tt_host_tensor, device)
76+
self.run()
77+
self.validate()
17178

172-
logger.info("Running optimized TTNN model pass...")
173-
# Optimized run
174-
self.input_tensor = ttnn.to_device(tt_host_tensor, device)
175-
self.run()
176-
self.validate()
79+
# --------------------------- Setup & helpers ---------------------------
80+
81+
@classmethod
def _maybe_seed(cls):
    """Seed Torch's RNGs on the first call; do nothing once ``_seeded`` is set.

    Seeds the CPU and CUDA generators with 42 and sets cuDNN's
    ``deterministic`` flag, then sets ``_seeded`` on the class so later
    calls return without reseeding.
    """
    if cls._seeded:
        return

    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Remember that seeding happened so subsequent instances skip it.
    cls._seeded = True
17788

17889
def get_mesh_mappers(self, device):
17990
if device.get_num_devices() != 1:
180-
inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0)
181-
weights_mesh_mapper = None
182-
output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0)
183-
else:
184-
inputs_mesh_mapper = None
185-
weights_mesh_mapper = None
186-
output_mesh_composer = None
187-
return inputs_mesh_mapper, weights_mesh_mapper, output_mesh_composer
91+
return (
92+
ttnn.ShardTensorToMesh(device, dim=0), # inputs
93+
None, # weights
94+
ttnn.ConcatMeshToTensor(device, dim=0), # outputs
95+
)
96+
return None, None, None
18897

189-
def run(self):
190-
self.output_tensor, self.output_tensor_2, self.output_tensor_3 = self.ttnn_model(self.input_tensor, self.device)
191-
return self.output_tensor, self.output_tensor_2, self.output_tensor_3
98+
@staticmethod
def _infer_and_set(module, params_holder, attr_name, run_fn):
    """Infer module args for ``module`` and store them as ``conv_args``.

    Does nothing when ``params_holder`` has no attribute named
    ``attr_name``. Otherwise calls ``infer_ttnn_module_args`` with
    ``run_fn`` as the ``run_model`` callback (and ``device=None``) and
    assigns the result to ``params_holder.<attr_name>.conv_args``.
    """
    if not hasattr(params_holder, attr_name):
        return

    inferred = infer_ttnn_module_args(model=module, run_model=run_fn, device=None)
    target = getattr(params_holder, attr_name)
    target.conv_args = inferred
192104

193-
def validate(self, output_tensor=None):
194-
output_tensor = self.output_tensor if output_tensor is None else output_tensor
195-
output_tensor = ttnn.to_torch(output_tensor, device=self.device, mesh_composer=self.output_mesh_composer)
196-
expected_shape = self.torch_output_tensor.shape
197-
output_tensor = torch.reshape(
198-
output_tensor, (expected_shape[0], expected_shape[2], expected_shape[3], expected_shape[1])
199-
)
200-
output_tensor = torch.permute(output_tensor, (0, 3, 1, 2))
105+
def _populate_decoder(self, torch_dec, params_dec):
106+
"""Warm up a single decoder (semantic or instance) to populate conv_args."""
107+
if not (torch_dec and params_dec):
108+
return
201109

202-
batch_size = output_tensor.shape[0]
110+
# Synthetic tensors that match typical Panoptic-DeepLab strides
111+
input_tensor = torch.randn(1, 2048, 32, 64)
112+
res3_tensor = torch.randn(1, 512, 64, 128)
113+
res2_tensor = torch.randn(1, 256, 128, 256)
203114

204-
valid_pcc = 0.97
205-
self.pcc_passed, self.pcc_message = check_with_pcc(self.torch_output_tensor, output_tensor, pcc=valid_pcc)
206-
assert self.pcc_passed, logger.error(f"Semantic Segmentation Head PCC check failed: {self.pcc_message}")
207-
logger.info(
208-
f"Panoptic DeepLab - Semantic Segmentation Head: batch_size={self.batch_size}, "
209-
f"act_dtype={model_config['ACTIVATIONS_DTYPE']}, weight_dtype={model_config['WEIGHTS_DTYPE']}, "
210-
f"math_fidelity={model_config['MATH_FIDELITY']}, PCC={self.pcc_message}, shape={self.output_tensor.shape}"
211-
)
115+
# ASPP
116+
self._infer_and_set(torch_dec.aspp, params_dec, "aspp", lambda m: m(input_tensor))
117+
aspp_out = torch_dec.aspp(input_tensor)
212118

213-
# Validate instance segmentation head outputs
214-
output_tensor = self.output_tensor_2
215-
output_tensor = ttnn.to_torch(output_tensor, device=self.device, mesh_composer=self.output_mesh_composer)
216-
expected_shape = self.torch_output_tensor_2.shape
217-
output_tensor = torch.reshape(
218-
output_tensor, (expected_shape[0], expected_shape[2], expected_shape[3], expected_shape[1])
219-
)
220-
output_tensor = torch.permute(output_tensor, (0, 3, 1, 2))
119+
# res3
120+
self._infer_and_set(torch_dec.res3, params_dec, "res3", lambda m: m(aspp_out, res3_tensor))
121+
res3_out = torch_dec.res3(aspp_out, res3_tensor)
221122

222-
batch_size = output_tensor.shape[0]
123+
# res2
124+
self._infer_and_set(torch_dec.res2, params_dec, "res2", lambda m: m(res3_out, res2_tensor))
125+
res2_out = torch_dec.res2(res3_out, res2_tensor)
223126

224-
valid_pcc = 0.97
225-
self.pcc_passed, self.pcc_message = check_with_pcc(self.torch_output_tensor_2, output_tensor, pcc=valid_pcc)
226-
assert self.pcc_passed, logger.error(f"Instance Segmentation Head PCC check failed: {self.pcc_message}")
227-
logger.info(
228-
f"Panoptic DeepLab - Instance Segmentation Offset Head: batch_size={self.batch_size}, "
229-
f"act_dtype={model_config['ACTIVATIONS_DTYPE']}, weight_dtype={model_config['WEIGHTS_DTYPE']}, "
230-
f"math_fidelity={model_config['MATH_FIDELITY']}, PCC={self.pcc_message}, shape={self.output_tensor_2.shape}"
231-
)
127+
# heads (one or two, if present)
128+
if hasattr(torch_dec, "head_1"):
129+
self._infer_and_set(torch_dec.head_1, params_dec, "head_1", lambda m: m(res2_out))
130+
if hasattr(torch_dec, "head_2"):
131+
self._infer_and_set(torch_dec.head_2, params_dec, "head_2", lambda m: m(res2_out))
232132

233-
output_tensor = self.output_tensor_3
234-
output_tensor = ttnn.to_torch(output_tensor, device=self.device, mesh_composer=self.output_mesh_composer)
235-
expected_shape = self.torch_output_tensor_3.shape
236-
output_tensor = torch.reshape(
237-
output_tensor, (expected_shape[0], expected_shape[2], expected_shape[3], expected_shape[1])
238-
)
239-
output_tensor = torch.permute(output_tensor, (0, 3, 1, 2))
133+
def _populate_all_decoders(self, torch_model, parameters):
134+
if hasattr(parameters, "semantic_decoder"):
135+
self._populate_decoder(torch_model.semantic_decoder, parameters.semantic_decoder)
136+
if hasattr(parameters, "instance_decoder"):
137+
self._populate_decoder(torch_model.instance_decoder, parameters.instance_decoder)
240138

241-
batch_size = output_tensor.shape[0]
139+
@staticmethod
def _tt_to_torch_nchw(tt_tensor, device, mesh_composer, expected_shape):
    """Bring a TTNN output back to Torch and lay it out like ``expected_shape``.

    The tensor is converted with ``ttnn.to_torch``, reshaped to
    ``(N, H, W, C)`` using the sizes in ``expected_shape`` (read as
    ``(N, C, H, W)``), and then permuted to ``(N, C, H, W)``.
    """
    host = ttnn.to_torch(tt_tensor, device=device, mesh_composer=mesh_composer)
    batch, channels, height, width = (expected_shape[axis] for axis in range(4))
    channels_last = host.reshape(batch, height, width, channels)
    return channels_last.permute(0, 3, 1, 2)
242145

243-
valid_pcc = 0.97
244-
self.pcc_passed, self.pcc_message = check_with_pcc(self.torch_output_tensor_3, output_tensor, pcc=valid_pcc)
245-
assert self.pcc_passed, logger.error(f"Instance Segmentation Head 2 PCC check failed: {self.pcc_message}")
246-
logger.info(
247-
f"Panoptic DeepLab - Instance Segmentation Center Head: batch_size={self.batch_size}, "
248-
f"act_dtype={model_config['ACTIVATIONS_DTYPE']}, weight_dtype={model_config['WEIGHTS_DTYPE']}, "
249-
f"math_fidelity={model_config['MATH_FIDELITY']}, PCC={self.pcc_message}, shape={self.output_tensor_3.shape}"
250-
)
146+
# --------------------------- Core runs/validation ---------------------------
251147

252-
return self.pcc_passed, self.pcc_message
148+
def run(self):
    """Run the TTNN model on ``self.input_tensor`` and keep its three outputs.

    The outputs are stored on ``self.output_tensor``,
    ``self.output_tensor_2`` and ``self.output_tensor_3`` and also
    returned as a tuple in that order.
    """
    outputs = self.ttnn_model(self.input_tensor, self.device)
    self.output_tensor, self.output_tensor_2, self.output_tensor_3 = outputs
    return self.output_tensor, self.output_tensor_2, self.output_tensor_3
151+
152+
def validate(self):
    """Compare the three TTNN head outputs against the Torch reference outputs.

    Each TTNN output is converted with ``_tt_to_torch_nchw`` and checked
    with ``check_with_pcc`` at ``self._PCC_THRESH``. A passing head is
    logged at info level.

    Returns:
        ``(True, message)`` when every head passes.

    Raises:
        AssertionError: on the first head whose PCC check fails. The
            exception message carries the head name and the PCC report.
    """
    checks = (
        ("Semantic Segmentation Head", self.output_tensor, self.torch_output_tensor),
        ("Instance Segmentation Offset Head", self.output_tensor_2, self.torch_output_tensor_2),
        ("Instance Segmentation Center Head", self.output_tensor_3, self.torch_output_tensor_3),
    )

    for name, tt_out, torch_ref in checks:
        out = self._tt_to_torch_nchw(tt_out, self.device, self.output_mesh_composer, torch_ref.shape)
        passed, msg = check_with_pcc(torch_ref, out, pcc=self._PCC_THRESH)
        if not passed:
            # The previous `assert passed, logger.error(...)` used the logger's
            # return value as the assertion message, so the failure carried no
            # text. Raising explicitly keeps the message and is not removed
            # when Python runs with -O.
            failure = f"{name} PCC check failed: {msg}"
            logger.error(failure)
            raise AssertionError(failure)

        logger.info(
            f"Panoptic DeepLab - {name}: batch_size={self.batch_size}, "
            f"act_dtype={self.model_config['ACTIVATIONS_DTYPE']}, "
            f"weight_dtype={self.model_config['WEIGHTS_DTYPE']}, "
            f"math_fidelity={self.model_config['MATH_FIDELITY']}, "
            f"PCC={msg}, shape={tt_out.shape}"
        )

    return True, f"All heads passed PCC ≥ {self._PCC_THRESH}"
174+
175+
176+
# --------------------------- Test config ---------------------------
254177

255178
model_config = {
256179
"MATH_FIDELITY": ttnn.MathFidelity.LoFi,
@@ -260,24 +183,6 @@ def validate(self, output_tensor=None):
260183

261184

262185
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
263-
@pytest.mark.parametrize(
264-
"batch_size, in_channels, height, width",
265-
[
266-
(1, 3, 512, 1024),
267-
],
268-
)
269-
def test_panoptic_deeplab(
270-
device,
271-
batch_size,
272-
in_channels,
273-
height,
274-
width,
275-
):
276-
PanopticDeepLabTestInfra(
277-
device,
278-
batch_size,
279-
in_channels,
280-
height,
281-
width,
282-
model_config,
283-
)
186+
@pytest.mark.parametrize("batch_size, in_channels, height, width", [(1, 3, 512, 1024)])
187+
def test_panoptic_deeplab(device, batch_size, in_channels, height, width):
188+
PanopticDeepLabTestInfra(device, batch_size, in_channels, height, width, model_config)

0 commit comments

Comments
 (0)