Skip to content

Commit 2046637

Browse files
committed
Adds license comments
1 parent ab29083 commit 2046637

36 files changed

+160
-171
lines changed

models/experimental/SSR/demo/ssr_demo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import os
25
import torch
36
import ttnn

models/experimental/SSR/tests/common/test_mlp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import torch
25
import pytest
36

@@ -83,4 +86,4 @@ def test_mlp(device, in_features, hidden_features, out_features, input_shape):
8386
else:
8487
logger.warning("SSR MLP Failed!")
8588

86-
assert does_pass
89+
assert does_pass, f"PCC check failed: {pcc_message}"

models/experimental/SSR/tests/test_ssr.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import torch
25
import pytest
36
import ttnn
@@ -12,10 +15,8 @@
1215
from models.experimental.SSR.tests.tile_refinement.test_HAB import create_relative_position_index
1316

1417
from ttnn.model_preprocessing import preprocess_model_parameters
15-
from models.utility_functions import (
16-
tt2torch_tensor,
17-
comp_pcc,
18-
)
18+
from models.utility_functions import tt2torch_tensor
19+
from tests.ttnn.utils_for_testing import check_with_pcc
1920

2021

2122
def create_ssr_preprocessor(device, args, num_cls):
@@ -166,11 +167,8 @@ def test_ssr_model(input_shape, num_cls, with_conv):
166167
tt_torch_sr = tt_torch_sr.permute(0, 3, 1, 2)
167168

168169
# Compare outputs
169-
sr_pass, sr_pcc_message = comp_pcc(ref_sr, tt_torch_sr, 0.95)
170-
fea3_pass, fea3_pcc_message = comp_pcc(ref_patch_fea3, tt_torch_patch_fea3, 0.95)
171-
172-
logger.info(f"SR Output PCC: {sr_pcc_message}")
173-
logger.info(f"Patch Fea3 PCC: {fea3_pcc_message}")
170+
sr_pass, sr_pcc_message = check_with_pcc(ref_sr, tt_torch_sr, 0.95)
171+
fea3_pass, fea3_pcc_message = check_with_pcc(ref_patch_fea3, tt_torch_patch_fea3, 0.95)
174172

175173
all_pass = sr_pass and fea3_pass
176174

@@ -179,8 +177,8 @@ def test_ssr_model(input_shape, num_cls, with_conv):
179177
else:
180178
logger.warning("TTSSR Test Failed!")
181179

182-
assert sr_pass, f"SR output comparison failed: {sr_pcc_message}"
183-
assert fea3_pass, f"Patch fea3 comparison failed: {fea3_pcc_message}"
180+
assert sr_pass, f"SR output failed PCC check: {sr_pcc_message}"
181+
assert fea3_pass, f"Patch fea3 failed PCC check: {fea3_pcc_message}"
184182

185183
finally:
186184
ttnn.close_device(device)

models/experimental/SSR/tests/tile_refinement/test_CAB.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,20 +118,9 @@ def test_cab_block(device, batch_size, num_feat, height, width, compress_ratio,
118118
# Compare outputs
119119
does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.97)
120120

121-
logger.info(f"Batch: {batch_size}, Features: {num_feat}, Size: {height}x{width}")
122-
logger.info(f"Compress ratio: {compress_ratio}, Squeeze factor: {squeeze_factor}")
123-
logger.info(f"Reference output shape: {ref_output.shape}")
124-
logger.info(f"TTNN output shape: {tt_torch_output.shape}")
125-
logger.info(pcc_message)
126-
127121
if does_pass:
128122
logger.info("CAB Block Passed!")
129123
else:
130124
logger.warning("CAB Block Failed!")
131125

132126
assert does_pass, f"PCC check failed: {pcc_message}"
133-
134-
# Verify output shapes match
135-
assert (
136-
ref_output.shape == tt_torch_output.shape
137-
), f"Shape mismatch: ref {ref_output.shape} vs ttnn {tt_torch_output.shape}"

models/experimental/SSR/tests/tile_refinement/test_HAB.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -313,19 +313,9 @@ def test_hab_block(device, batch_size, height, width, dim, num_heads, window_siz
313313
# Compare outputs
314314
does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.95)
315315

316-
logger.info(f"Batch: {batch_size}, Size: {height}x{width}, Dim: {dim}")
317-
logger.info(f"Heads: {num_heads}, Window: {window_size}, Shift: {shift_size}")
318-
logger.info(f"Reference output shape: {ref_output.shape}")
319-
logger.info(f"TTNN output shape: {tt_torch_output.shape}")
320-
logger.info(f"Actual Run: {actual_run_time} s")
321-
logger.info(pcc_message)
322-
323316
if does_pass:
324317
logger.info("HAB Block Passed!")
325318
else:
326319
logger.warning("HAB Block Failed!")
327320

328321
assert does_pass, f"PCC check failed: {pcc_message}"
329-
assert (
330-
ref_output.shape == tt_torch_output.shape
331-
), f"Shape mismatch: ref {ref_output.shape} vs ttnn {tt_torch_output.shape}"

models/experimental/SSR/tests/tile_refinement/test_OCAB.py

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import pytest
25
import torch
36
import torch.nn as nn
@@ -97,7 +100,8 @@ def custom_preprocessor(model, name):
97100
"dim, input_resolution, window_size, overlap_ratio, num_heads, input_shape",
98101
((180, (64, 64), 16, 0.5, 6, (1, 4096, 180)),),
99102
)
100-
def test_ocab(dim, input_resolution, window_size, overlap_ratio, num_heads, input_shape):
103+
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
104+
def test_ocab(device, dim, input_resolution, window_size, overlap_ratio, num_heads, input_shape):
101105
x = torch.randn(input_shape)
102106

103107
# Create reference OCAB layer
@@ -121,41 +125,36 @@ def test_ocab(dim, input_resolution, window_size, overlap_ratio, num_heads, inpu
121125

122126
ref_output = ref_layer(x, x_size, rpi)
123127

124-
device = ttnn.open_device(device_id=0, l1_small_size=32768)
125128
ttnn.synchronize_device(device)
126129

127-
try:
128-
parameters = preprocess_model_parameters(
129-
initialize_model=lambda: ref_layer, custom_preprocessor=create_ocab_preprocessor(device), device=device
130-
)
131-
132-
tt_layer = TTOCAB(
133-
device=device,
134-
dim=dim,
135-
input_resolution=input_resolution,
136-
window_size=window_size,
137-
overlap_ratio=overlap_ratio,
138-
num_heads=num_heads,
139-
parameters=parameters,
140-
)
141-
142-
tt_input = ttnn.from_torch(
143-
x, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG
144-
)
145-
tt_rpi = ttnn.from_torch(rpi, device=device, layout=ttnn.TILE_LAYOUT)
146-
tt_output = tt_layer.forward(tt_input, x_size, tt_rpi)
147-
tt_torch_output = tt2torch_tensor(tt_output)
148-
149-
does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.99)
150-
151-
logger.info(pcc_message)
152-
153-
if does_pass:
154-
logger.info("OCAB Layer Passed!")
155-
else:
156-
logger.warning("OCAB Layer Failed!")
157-
158-
finally:
159-
ttnn.close_device(device)
160-
161-
assert does_pass
130+
parameters = preprocess_model_parameters(
131+
initialize_model=lambda: ref_layer, custom_preprocessor=create_ocab_preprocessor(device), device=device
132+
)
133+
134+
tt_layer = TTOCAB(
135+
device=device,
136+
dim=dim,
137+
input_resolution=input_resolution,
138+
window_size=window_size,
139+
overlap_ratio=overlap_ratio,
140+
num_heads=num_heads,
141+
parameters=parameters,
142+
)
143+
144+
tt_input = ttnn.from_torch(
145+
x, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG
146+
)
147+
tt_rpi = ttnn.from_torch(rpi, device=device, layout=ttnn.TILE_LAYOUT)
148+
tt_output = tt_layer.forward(tt_input, x_size, tt_rpi)
149+
tt_torch_output = tt2torch_tensor(tt_output)
150+
151+
does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.99)
152+
153+
logger.info(pcc_message)
154+
155+
if does_pass:
156+
logger.info("OCAB Layer Passed!")
157+
else:
158+
logger.warning("OCAB Layer Failed!")
159+
160+
assert does_pass, f"PCC check failed: {pcc_message}"

models/experimental/SSR/tests/tile_refinement/test_atten_blocks.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,9 @@ def test_atten_blocks(device, batch_size, height, width, dim, num_heads, window_
136136
# Compare outputs
137137
does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.90)
138138

139-
logger.info(f"Batch: {batch_size}, Size: {height}x{width}, Dim: {dim}")
140-
logger.info(f"Heads: {num_heads}, Window: {window_size}, Depth: {depth}")
141-
logger.info(f"Overlap ratio: {overlap_ratio}, MLP ratio: {mlp_ratio}")
142-
logger.info(f"Reference output shape: {ref_output.shape}")
143-
logger.info(f"TTNN output shape: {tt_torch_output.shape}")
144-
logger.info(pcc_message)
145-
146139
if does_pass:
147140
logger.info("AttenBlocks Passed!")
148141
else:
149142
logger.warning("AttenBlocks Failed!")
150143

151144
assert does_pass, f"PCC check failed: {pcc_message}"
152-
assert (
153-
ref_output.shape == tt_torch_output.shape
154-
), f"Shape mismatch: ref {ref_output.shape} vs ttnn {tt_torch_output.shape}"

models/experimental/SSR/tests/tile_refinement/test_patch_embed_tile_refinement.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,5 +161,3 @@ def test_patch_embed_simple(device, batch_size, img_size, patch_size, in_chans,
161161
logger.warning("TR PatchEmbed Failed!")
162162

163163
assert does_pass, f"PCC check failed: {pcc_message}"
164-
165-
ttnn.close_device(device)

models/experimental/SSR/tests/tile_refinement/test_patch_unembed.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import pytest
25
import torch
36
import ttnn
@@ -57,5 +60,3 @@ def test_tt_patch_unembed(device, batch_size, img_size, patch_size, in_chans, em
5760
logger.warning("TR PatchEmbed Failed!")
5861

5962
assert does_pass, f"PCC check failed: {pcc_message}"
60-
61-
ttnn.close_device(device)

models/experimental/SSR/tests/tile_refinement/test_tile_refinement.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import torch
25
import pytest
36
import ttnn
@@ -118,8 +121,9 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
118121
(64, 2, 180, (6, 6, 6, 6, 6, 6), (6, 6, 6, 6, 6, 6), 16, 2, 4, (3, 3, 64, 64)),
119122
],
120123
)
124+
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
121125
def test_tile_refinement(
122-
img_size, patch_size, embed_dim, depths, num_heads, window_size, mlp_ratio, upscale, input_shape
126+
device, img_size, patch_size, embed_dim, depths, num_heads, window_size, mlp_ratio, upscale, input_shape
123127
):
124128
"""Test TTTileRefinement model against PyTorch reference"""
125129

@@ -166,8 +170,6 @@ def test_tile_refinement(
166170
# Create params dictionary
167171
params = {"rpi_sa": rpi_sa, "attn_mask": attn_mask, "rpi_oca": rpi_oca}
168172

169-
device = ttnn.open_device(device_id=0, l1_small_size=32768)
170-
171173
tt_rpi_sa = ttnn.from_torch(rpi_sa, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32)
172174

173175
tt_rpi_oca = ttnn.from_torch(rpi_oca, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.uint32)
@@ -179,11 +181,7 @@ def test_tile_refinement(
179181
# Get reference output (both image and features)
180182
with torch.no_grad():
181183
ref_output, ref_features = ref_model(x)
182-
# ref_output = ref_model(x)
183184

184-
# Open TTNN device
185-
186-
try:
187185
# Preprocess model parameters
188186
parameters = preprocess_model_parameters(
189187
initialize_model=lambda: ref_model,
@@ -247,6 +245,3 @@ def test_tile_refinement(
247245

248246
assert output_pass, f"Output comparison failed: {output_pcc_message}"
249247
assert features_pass, f"Features comparison failed: {features_pcc_message}"
250-
251-
finally:
252-
ttnn.close_device(device)

0 commit comments

Comments
 (0)