Skip to content

Commit 6fc7511

Browse files
tests refactor
1 parent bcc8d02 commit 6fc7511

29 files changed

+3624
-2
lines changed

models/experimental/SSR/tests/common/test_mlp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from ttnn.model_preprocessing import preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight
1212
from models.utility_functions import (
1313
tt2torch_tensor,
14-
comp_pcc,
1514
)
15+
from tests.ttnn.utils_for_testing import check_with_pcc
1616

1717

1818
def create_mlp_preprocessor(device):
@@ -74,7 +74,7 @@ def test_mlp(device, in_features, hidden_features, out_features, input_shape):
7474
tt_output = tt_layer(tt_input)
7575
tt_torch_output = tt2torch_tensor(tt_output)
7676

77-
does_pass, pcc_message = comp_pcc(ref_output, tt_torch_output, 0.99)
77+
does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.99)
7878

7979
logger.info(pcc_message)
8080

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import pytest
5+
import torch
6+
import torch.nn as nn
7+
import ttnn
8+
from loguru import logger
9+
10+
from models.experimental.SSR.tt.tile_refinement import TTCAB
11+
from tests.ttnn.utils_for_testing import check_with_pcc
12+
from models.experimental.SSR.reference.SSR.model.tile_refinement import ChannelAttention
13+
from models.experimental.SSR.tests.tile_refinement.test_channel_attention import create_channel_attention_preprocessor
14+
15+
16+
class CAB(nn.Module):
17+
"""Reference PyTorch CAB implementation"""
18+
19+
def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
20+
super(CAB, self).__init__()
21+
self.cab = nn.Sequential(
22+
nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
23+
nn.GELU(),
24+
nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
25+
ChannelAttention(num_feat, squeeze_factor),
26+
)
27+
28+
def forward(self, x):
29+
return self.cab(x)
30+
31+
32+
def create_cab_preprocessor(device):
33+
def custom_preprocessor(torch_model, name, ttnn_module_args):
34+
params = {}
35+
36+
# Extract the sequential layers from CAB
37+
cab_layers = list(torch_model.cab.children())
38+
conv1 = cab_layers[0] # First Conv2d layer
39+
conv2 = cab_layers[2] # Second Conv2d layer (after GELU)
40+
channel_attention = cab_layers[3] # ChannelAttention module
41+
42+
# Preprocess first convolution (3x3)
43+
params["conv1"] = {
44+
"weight": ttnn.from_torch(conv1.weight, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
45+
"bias": ttnn.from_torch(conv1.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
46+
}
47+
48+
# Preprocess second convolution (3x3)
49+
params["conv2"] = {
50+
"weight": ttnn.from_torch(conv2.weight, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
51+
"bias": ttnn.from_torch(conv2.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
52+
}
53+
54+
# Preprocess channel attention using existing preprocessor
55+
channel_attention_preprocessor = create_channel_attention_preprocessor(device)
56+
params["channel_attention"] = channel_attention_preprocessor(
57+
channel_attention, "channel_attention", ttnn_module_args
58+
)
59+
60+
return params
61+
62+
return custom_preprocessor
63+
64+
65+
@pytest.mark.parametrize(
66+
"batch_size, num_feat, height, width, compress_ratio, squeeze_factor",
67+
[
68+
(1, 180, 64, 64, 3, 30), # SSR config
69+
],
70+
)
71+
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
72+
def test_cab_block(device, batch_size, num_feat, height, width, compress_ratio, squeeze_factor):
73+
torch.manual_seed(0)
74+
75+
# Create reference model
76+
ref_model = CAB(num_feat=num_feat, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
77+
ref_model.eval()
78+
79+
# Create input tensor
80+
input_tensor = torch.randn(batch_size, num_feat, height, width)
81+
82+
# Reference forward pass
83+
with torch.no_grad():
84+
ref_output = ref_model(input_tensor)
85+
86+
parameters = ttnn.model_preprocessing.preprocess_model(
87+
initialize_model=lambda: ref_model,
88+
custom_preprocessor=create_cab_preprocessor(device),
89+
device=device,
90+
run_model=lambda model: model(input_tensor),
91+
)
92+
93+
tt_model = TTCAB(
94+
device=device,
95+
parameters=parameters,
96+
num_feat=num_feat,
97+
compress_ratio=compress_ratio,
98+
squeeze_factor=squeeze_factor,
99+
memory_config=ttnn.DRAM_MEMORY_CONFIG,
100+
)
101+
102+
# Convert input to TTNN format (NHWC)
103+
tt_input = ttnn.from_torch(
104+
input_tensor.permute(0, 2, 3, 1),
105+
device=device,
106+
layout=ttnn.TILE_LAYOUT,
107+
dtype=ttnn.bfloat16,
108+
memory_config=ttnn.L1_MEMORY_CONFIG, # NCHW -> NHWC
109+
)
110+
111+
# TTNN forward pass
112+
tt_output = tt_model(tt_input)
113+
114+
# Convert back to PyTorch format
115+
tt_torch_output = ttnn.to_torch(tt_output)
116+
tt_torch_output = tt_torch_output.permute(0, 3, 1, 2) # NHWC -> NCHW
117+
118+
# Compare outputs
119+
does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.97)
120+
121+
logger.info(f"Batch: {batch_size}, Features: {num_feat}, Size: {height}x{width}")
122+
logger.info(f"Compress ratio: {compress_ratio}, Squeeze factor: {squeeze_factor}")
123+
logger.info(f"Reference output shape: {ref_output.shape}")
124+
logger.info(f"TTNN output shape: {tt_torch_output.shape}")
125+
logger.info(pcc_message)
126+
127+
if does_pass:
128+
logger.info("CAB Block Passed!")
129+
else:
130+
logger.warning("CAB Block Failed!")
131+
132+
assert does_pass, f"PCC check failed: {pcc_message}"
133+
134+
# Verify output shapes match
135+
assert (
136+
ref_output.shape == tt_torch_output.shape
137+
), f"Shape mismatch: ref {ref_output.shape} vs ttnn {tt_torch_output.shape}"

0 commit comments

Comments
 (0)