
Commit 4ddba24

tests update
1 parent 9970595 commit 4ddba24

8 files changed: 35 additions & 45 deletions

models/experimental/SSR/tests/test_RHAG.py

Lines changed: 1 addition & 6 deletions
@@ -93,12 +93,7 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
 @pytest.mark.parametrize(
     "batch_size, height, width, dim, num_heads, window_size, depth, overlap_ratio, mlp_ratio, resi_connection",
     [
-        # (1, 32, 32, 180, 6, 16, 2, 0.5, 2.0, "1conv"),  # Standard configuration with conv
-        (1, 64, 64, 180, 6, 16, 6, 0.5, 2.0, "1conv"),  # Standard configuration with conv
-        # (1, 64, 64, 180, 6, 16, 2, 0.5, 2.0, "1conv"),  # Standard configuration with conv
-        # (1, 32, 32, 96, 3, 8, 3, 0.25, 4.0, "identity"),  # Identity connection
-        # (2, 64, 64, 180, 6, 16, 1, 0.5, 2.0, "1conv"),  # Batch size 2
-        # (1, 128, 128, 192, 6, 16, 2, 0.75, 3.0, "1conv"),  # Larger resolution
+        (1, 64, 64, 180, 6, 16, 6, 0.5, 2, "1conv"),  # SSR config
     ],
 )
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)

models/experimental/SSR/tests/test_atten_blocks.py

Lines changed: 1 addition & 5 deletions
@@ -37,11 +37,7 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
 @pytest.mark.parametrize(
     "batch_size, height, width, dim, num_heads, window_size, depth, overlap_ratio, mlp_ratio",
     [
-        (1, 64, 64, 180, 6, 16, 2, 0.5, 2.0),  # Standard configuration
-        # (1, 32, 32, 96, 3, 8, 3, 0.25, 4.0),  # Smaller resolution, more blocks
-        # (2, 64, 64, 180, 6, 16, 1, 0.5, 2.0),  # Batch size 2, single block
-        # (1, 128, 128, 192, 6, 16, 2, 0.75, 3.0),  # Larger resolution
-        # (2, 64, 64, 180, 6, 16, 6, 0.5, 2),  # Network config
+        (1, 64, 64, 180, 6, 16, 6, 0.5, 2),  # SSR config
     ],
 )
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)

models/experimental/SSR/tests/test_patch_embed_tile_selection.py renamed to models/experimental/SSR/tests/test_patch_embed_tile_refinement.py

Lines changed: 6 additions & 7 deletions
@@ -131,11 +131,8 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
 @pytest.mark.parametrize(
     "batch_size, img_size, patch_size, in_chans, embed_dim, norm_layer",
     [
-        (1, 224, 16, 3, 768, None),  # Standard ViT-Base config
-        (2, 256, 4, 3, 96, None),  # Smaller embedding
-        (1, 32, 4, 180, 180, None),  # Your SSR config
-        # (1, 64, 8, 3, 192, nn.LayerNorm),  # With normalization
-        (4, 128, 16, 3, 384, None),  # Batch size 4
+        (1, 64, 2, 3, 180, None),  # TR blk test
+        (1, 64, 4, 3, 180, None),  # HAT blk test
     ],
 )
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@@ -200,11 +197,13 @@ def test_patch_embed_simple(device, batch_size, img_size, patch_size, in_chans,
     logger.info(pcc_message)

     if does_pass:
-        logger.info("Simple PatchEmbed Passed!")
+        logger.info("TR PatchEmbed Passed!")
     else:
-        logger.warning("Simple PatchEmbed Failed!")
+        logger.warning("TR PatchEmbed Failed!")

     assert does_pass, f"PCC check failed: {pcc_message}"
     assert (
         ref_output.shape == tt_torch_output.shape
     ), f"Shape mismatch: ref {ref_output.shape} vs ttnn {tt_torch_output.shape}"
+
+    ttnn.close_device(device)
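
The tests in this commit now release the device explicitly at the end of the test body. A minimal sketch of the open/close pairing, assuming the standard ttnn device API; the device id is a placeholder and the tests themselves receive `device` from a pytest fixture rather than opening it inline:

import ttnn

device = ttnn.open_device(device_id=0)  # placeholder id, not from this commit
try:
    pass  # build the TT model and run the forward pass here
finally:
    ttnn.close_device(device)  # mirrors the close added at the end of these tests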

models/experimental/SSR/tests/test_patch_unembed.py

Lines changed: 25 additions & 9 deletions
@@ -1,18 +1,22 @@
 import pytest
 import torch
 import ttnn
-from tests.ttnn.utils_for_testing import assert_with_pcc
+from loguru import logger
+from models.utility_functions import comp_pcc
+
 from models.utility_functions import torch_random

 from models.experimental.SSR.reference.SSR.model.tile_refinement import PatchUnEmbed
 from models.experimental.SSR.tt.patch_unembed import TTPatchUnEmbed


-@pytest.mark.parametrize("batch_size", [1, 2, 8])
-@pytest.mark.parametrize("img_size", [64])
-@pytest.mark.parametrize("patch_size", [1])
-@pytest.mark.parametrize("in_chans", [180])
-@pytest.mark.parametrize("embed_dim", [180])
+@pytest.mark.parametrize(
+    "batch_size, img_size, patch_size, in_chans, embed_dim",
+    [
+        # (1, 64, 4, 3, 180),  # TR blk test
+        (1, 64, 2, 3, 180),  # HAT blk test
+    ],
+)
 def test_tt_patch_unembed(device, batch_size, img_size, patch_size, in_chans, embed_dim):
     """Test TTPatchUnEmbed against PyTorch reference implementation"""
     torch.manual_seed(0)
@@ -37,17 +41,29 @@ def test_tt_patch_unembed(device, batch_size, img_size, patch_size, in_chans, em

     # Convert input to TTNN format
     ttnn_input = ttnn.from_torch(
-        torch_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
+        torch_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
     )

     # Run TTNN model
     ttnn_output = tt_model(ttnn_input, patches_resolution)
     ttnn_output_torch = ttnn.to_torch(ttnn_output)

+    # Compare outputs
+    does_pass, pcc_message = comp_pcc(torch_output, ttnn_output_torch, 0.99)
+
+    logger.info(f"Reference output shape: {torch_output.shape}")
+    logger.info(f"TTNN output shape: {ttnn_output_torch.shape}")
+    logger.info(pcc_message)
+
+    if does_pass:
+        logger.info("TR PatchEmbed Passed!")
+    else:
+        logger.warning("TR PatchEmbed Failed!")
+
+    assert does_pass, f"PCC check failed: {pcc_message}"
     # Assert shapes match
     assert (
         torch_output.shape == ttnn_output_torch.shape
     ), f"Shape mismatch: {torch_output.shape} vs {ttnn_output_torch.shape}"

-    # Assert values match with PCC
-    assert_with_pcc(torch_output, ttnn_output_torch, 0.999)
+    ttnn.close_device(device)
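
This file swaps the hard `assert_with_pcc` call for `comp_pcc`, which returns a pass/fail flag plus a message that can be logged before asserting. At its core the check is a Pearson correlation coefficient over the flattened tensors; a minimal self-contained sketch of that idea in plain PyTorch (this is the concept, not the comp_pcc implementation itself):

import torch

def pearson_cc(golden: torch.Tensor, actual: torch.Tensor) -> float:
    # Flatten both tensors and compute their Pearson correlation coefficient
    g = golden.flatten().to(torch.float32)
    a = actual.flatten().to(torch.float32)
    return torch.corrcoef(torch.stack([g, a]))[0, 1].item()

golden = torch.randn(1, 64 * 64, 180)
actual = golden + 0.01 * torch.randn_like(golden)  # small bfloat16-style noise
assert pearson_cc(golden, actual) >= 0.99  # same 0.99 threshold the test uses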

models/experimental/SSR/tests/test_tile_refinement.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
         # Test configuration - adjust based on your requirements
         # (64, 1, 96, (2, 2), (6, 6), 7, 4, 2, (1, 3, 64, 64)),
         # (64, 1, 180, (2, 2, 2), (6, 6, 6), 16, 2, 4, (1, 3, 64, 64)),
-        (64, 1, 180, (6, 6, 6, 6, 6, 6), (6, 6, 6, 6, 6, 6), 16, 2, 4, (3, 3, 64, 64)),
+        (64, 2, 180, (6, 6, 6, 6, 6, 6), (6, 6, 6, 6, 6, 6), 16, 2, 4, (3, 3, 64, 64)),
     ],
 )
 def test_tile_refinement(
def test_tile_refinement(

models/experimental/SSR/tt/atten_blocks.py

Lines changed: 0 additions & 3 deletions
@@ -73,11 +73,8 @@ def __init__(
             mlp_ratio=mlp_ratio,
         )

-        # Downsample layer (if provided)
         self.downsample = None
         if downsample is not None:
-            # Note: You'll need to implement the downsample layer in TTNN
-            # This depends on what type of downsampling is used
            self.downsample = downsample

    def forward(self, x, x_size, params):
def forward(self, x, x_size, params):

models/experimental/SSR/tt/patch_embed_tile_refinement.py

Lines changed: 0 additions & 3 deletions
@@ -36,9 +36,6 @@ def __init__(
         self.patches_resolution = [self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1]]
         self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

-        self.in_chans = in_chans
-        self.embed_dim = embed_dim
-        self.device = device
         self.memory_config = memory_config or ttnn.DRAM_MEMORY_CONFIG

         # Store normalization parameters if provided

models/experimental/SSR/tt/patch_unembed.py

Lines changed: 1 addition & 11 deletions
@@ -8,16 +8,6 @@ class TTPatchUnEmbed(LightweightModule):
     def __init__(self, mesh_device, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
         super().__init__()

-        self.mesh_device = mesh_device
-
-        # Convert to tuples like in the original
-        self.img_size = (img_size, img_size)
-        self.patch_size = (patch_size, patch_size)
-
-        self.patches_resolution = [self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1]]
-        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
-
-        self.in_chans = in_chans
         self.embed_dim = embed_dim

     def forward(self, x, x_size):
@@ -27,6 +17,6 @@ def forward(self, x, x_size):
         x = ttnn.permute(x, (0, 2, 1))  # (batch_size, embed_dim, num_patches)

         # Reshape to spatial dimensions
-        x = ttnn.reshape(x, (batch_size, self.embed_dim, x_size[0], x_size[1]))
+        x = ttnn.reshape(x, (batch_size, self.embed_dim, x_size[0], x_size[1]), memory_config=ttnn.L1_MEMORY_CONFIG)

         return x
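
After this cleanup, TTPatchUnEmbed keeps only `embed_dim` and does a permute plus reshape, now pinned to L1 memory. The tensor movement it performs, sketched in plain PyTorch for the HAT-style shapes used in the tests above (shapes assumed from those test configs):

import torch

# (batch, num_patches, embed_dim) token layout, as produced by the attention blocks
batch_size, embed_dim, h, w = 1, 180, 64, 64
tokens = torch.randn(batch_size, h * w, embed_dim)

x = tokens.permute(0, 2, 1)                 # (batch, embed_dim, num_patches)
x = x.reshape(batch_size, embed_dim, h, w)  # back to spatial (batch, C, H, W)

assert x.shape == (1, 180, 64, 64)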
