Skip to content

Commit 6c5b055

Browse files
committed
uniform test infra
1 parent a2530b0 commit 6c5b055

File tree

6 files changed

+405
-424
lines changed

6 files changed

+405
-424
lines changed
Lines changed: 90 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,134 @@
11
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import pytest
54
import torch
5+
import pytest
66
from loguru import logger
77
import ttnn
8-
from ttnn.model_preprocessing import preprocess_model_parameters
98

9+
from ttnn.model_preprocessing import preprocess_model_parameters
1010
from tests.ttnn.utils_for_testing import check_with_pcc
11-
from models.experimental.panoptic_deeplab.tt.custom_preprocessing import create_custom_mesh_preprocessor
1211
from models.experimental.panoptic_deeplab.reference.aspp import ASPPModel
1312
from models.experimental.panoptic_deeplab.tt.aspp import TTASPP
1413
from models.experimental.panoptic_deeplab.common import load_torch_model_state
14+
from models.experimental.panoptic_deeplab.tt.custom_preprocessing import create_custom_mesh_preprocessor
1515

1616

17-
class AsppTestInfra:
18-
def __init__(
19-
self,
20-
device,
21-
batch_size,
22-
input_channels,
23-
height,
24-
width,
25-
model_config,
26-
name,
27-
):
17+
class ASPPTestInfra:
    """End-to-end check of the TTNN ASPP module against its PyTorch reference.

    Constructing an instance runs the whole test: it seeds torch, loads the
    reference ``ASPPModel`` weights for ``name``, computes the reference
    output on a random input, preprocesses the weights for TTNN, and then
    runs and validates the ``TTASPP`` model twice (once per phase listed in
    ``__init__``). A failed PCC comparison raises ``AssertionError``.

    Args:
        device: TTNN device or mesh device; must expose ``get_num_devices()``.
        batch_size: Per-device batch size of the synthetic input.
        input_channels: Channel count of the synthetic input.
        height: Spatial height of the synthetic input.
        width: Spatial width of the synthetic input.
        model_config: Mapping passed through to ``TTASPP``; ``validate`` reads
            the ``ACTIVATIONS_DTYPE``, ``WEIGHTS_DTYPE`` and ``MATH_FIDELITY``
            keys for its summary log line.
        name: Module name handed to ``load_torch_model_state`` to select the
            weights to load (e.g. ``"semantic_decoder.aspp"``).
    """

    def __init__(self, device, batch_size, input_channels, height, width, model_config, name):
        super().__init__()
        if not hasattr(self, "_model_initialized"):
            # Fixed seeds so the random input (and thus the PCC) is reproducible.
            torch.manual_seed(42)
            torch.cuda.manual_seed_all(42)
            torch.backends.cudnn.deterministic = True
            self._model_initialized = True

        # Core configuration
        self.device = device
        self.batch_size = batch_size
        self.input_channels = input_channels
        self.height = height
        self.width = width
        self.model_config = model_config
        self.name = name
        self.num_devices = device.get_num_devices()

        # Mesh mappers (all None on a single device)
        self.inputs_mesh_mapper, self.weights_mesh_mapper, self.output_mesh_composer = self.get_mesh_mappers(device)

        logger.info(f"Initializing ASPP test for module: {name}")

        # Torch reference model
        torch_model = ASPPModel()
        torch_model = load_torch_model_state(torch_model, name)

        # Synthetic input and golden output
        self.torch_input_tensor = self._create_input_tensor()
        self.torch_output_tensor = torch_model(self.torch_input_tensor)

        # Convert the torch weights into TTNN parameters
        parameters = preprocess_model_parameters(
            initialize_model=lambda: torch_model,
            custom_preprocessor=create_custom_mesh_preprocessor(self.weights_mesh_mapper),
            device=None,
        )

        logger.info("Converting input to TTNN tensor...")

        # Run the full convert -> build -> run -> validate cycle once per phase.
        for phase in ("JIT configuration", "optimized"):
            logger.info(f"Running TTNN model pass ({phase})...")

            # Re-convert the input on every pass (TTNN may deallocate buffers).
            # The torch input is (N, C, H, W); permute to channels-last first.
            tt_host_tensor = ttnn.from_torch(
                self.torch_input_tensor.permute(0, 2, 3, 1),
                dtype=ttnn.bfloat8_b,
                device=self.device,
                mesh_mapper=self.inputs_mesh_mapper,
            )
            self.input_tensor = ttnn.to_device(tt_host_tensor, self.device, memory_config=ttnn.L1_MEMORY_CONFIG)

            # Build a fresh TTNN model per pass in case it is not stateless.
            # This is the only construction site: building one before the loop
            # as well would just be overwritten here.
            self.ttnn_model = TTASPP(parameters, self.model_config)

            self.run()
            self.validate()

    def _create_input_tensor(self):
        """Return a random float32 input of shape (batch * devices, C, H, W)."""
        shape = (self.batch_size * self.num_devices, self.input_channels, self.height, self.width)
        logger.info(f"Generating synthetic input tensor of shape {shape}")
        return torch.randn(shape, dtype=torch.float32)

    @classmethod
    def get_mesh_mappers(cls, device):
        """Return ``(inputs_mapper, weights_mapper, output_composer)`` for ``device``.

        With more than one device, inputs are sharded along dim 0 and outputs
        are concatenated along dim 0; the weights mapper is ``None``. With a
        single device all three entries are ``None``.
        """
        if device.get_num_devices() != 1:
            return (
                ttnn.ShardTensorToMesh(device, dim=0),  # inputs
                None,  # weights
                ttnn.ConcatMeshToTensor(device, dim=0),  # outputs
            )
        return None, None, None

    def run(self):
        """Run the TTNN model on ``self.input_tensor`` and store/return the output."""
        logger.info("Running TTNN ASPP model...")
        self.output_tensor = self.ttnn_model(self.input_tensor, self.device)
        return self.output_tensor

    def _tt_to_torch_nchw(self, tt_tensor, expected_shape):
        """Convert a TTNN tensor to a torch tensor laid out like ``expected_shape``.

        ``expected_shape`` is indexed as (N, C, H, W): the converted tensor is
        first reshaped to (N, H, W, C) and then permuted to (N, C, H, W).
        """
        torch_tensor = ttnn.to_torch(tt_tensor, device=self.device, mesh_composer=self.output_mesh_composer)
        torch_tensor = torch.reshape(
            torch_tensor,
            (expected_shape[0], expected_shape[2], expected_shape[3], expected_shape[1]),
        )
        return torch.permute(torch_tensor, (0, 3, 1, 2))

    def validate(self):
        """Compare the last TTNN output with the torch reference via PCC.

        Returns:
            ``(True, msg)`` where ``msg`` is the message returned by
            ``check_with_pcc``.

        Raises:
            AssertionError: If the PCC check does not pass at threshold 0.99.
                The exception message includes the PCC message.
        """
        logger.info("Validating TTNN output against PyTorch...")
        tt_output_tensor_torch = self._tt_to_torch_nchw(self.output_tensor, self.torch_output_tensor.shape)

        # Deallocate to save memory; the torch copy is all we need from here on.
        ttnn.deallocate(self.output_tensor)

        pcc_threshold = 0.99
        passed, msg = check_with_pcc(self.torch_output_tensor, tt_output_tensor_torch, pcc=pcc_threshold)
        if not passed:
            logger.error(f"ASPP PCC check failed: {msg}")
        # Give the assertion a real message: logger.error() returns None, so
        # using its return value here would raise an AssertionError with no detail.
        assert passed, f"ASPP PCC check failed: {msg}"

        logger.info(
            f"ASPP layer `{self.name}` passed: "
            f"batch_size={self.batch_size}, "
            f"act_dtype={self.model_config['ACTIVATIONS_DTYPE']}, "
            f"weight_dtype={self.model_config['WEIGHTS_DTYPE']}, "
            f"math_fidelity={self.model_config['MATH_FIDELITY']}, "
            f"PCC={msg}"
        )

        return True, msg
115132

116133

117134
model_config = {
@@ -130,12 +147,12 @@ def validate(self, output_tensor=None):
130147
)
131148
@pytest.mark.parametrize("name", ["semantic_decoder.aspp", "instance_decoder.aspp"])
132149
def test_aspp(device, batch_size, input_channels, height, width, name):
133-
AsppTestInfra(
134-
device,
135-
batch_size,
136-
input_channels,
137-
height,
138-
width,
139-
model_config,
140-
name,
150+
ASPPTestInfra(
151+
device=device,
152+
batch_size=batch_size,
153+
input_channels=input_channels,
154+
height=height,
155+
width=width,
156+
model_config=model_config,
157+
name=name,
141158
)

0 commit comments

Comments
 (0)