Skip to content

Commit 3558057

Browse files
patch_embed_tr test file name fix in imports, demo files init wip
1 parent 4ddba24 commit 3558057

File tree

3 files changed

+187
-2
lines changed

3 files changed

+187
-2
lines changed
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import os
2+
import torch
3+
import ttnn
4+
from loguru import logger
5+
from PIL import Image
6+
import torchvision.transforms as transforms
7+
8+
# Import reference and TTNN models
9+
from models.experimental.SSR.reference.SSR.model.ssr import SSR, SSR_wo_conv
10+
from models.experimental.SSR.tt.ssr import TTSSR, TTSSR_wo_conv
11+
12+
from ttnn.model_preprocessing import preprocess_model_parameters
13+
from models.utility_functions import (
14+
tt2torch_tensor,
15+
comp_pcc,
16+
)
17+
18+
from models.experimental.SSR.tests.test_ssr import create_ssr_preprocessor
19+
20+
21+
class Args:
22+
"""Args class for SSR model"""
23+
24+
def __init__(self):
25+
self.token_size = 4
26+
self.imgsz = 256
27+
self.patchsz = 2
28+
self.pretrain = False
29+
self.ckpt = None
30+
self.dim = 96
31+
32+
33+
def load_image(image_path, target_size=(256, 256)):
34+
"""Load and preprocess image for SSR model"""
35+
# Load image
36+
image = Image.open(image_path).convert("RGB")
37+
38+
# Define transforms
39+
transform = transforms.Compose(
40+
[
41+
transforms.Resize(target_size),
42+
transforms.ToTensor(),
43+
]
44+
)
45+
46+
# Apply transforms and add batch dimension
47+
image_tensor = transform(image).unsqueeze(0) # Shape: (1, 3, 256, 256)
48+
49+
return image_tensor
50+
51+
52+
def save_tensor_as_image(tensor, output_path):
53+
"""Save tensor as image"""
54+
# Remove batch dimension and convert to numpy
55+
if tensor.dim() == 4:
56+
tensor = tensor.squeeze(0) # Remove batch dimension
57+
58+
# Clamp values to [0, 1] range
59+
tensor = torch.clamp(tensor, 0, 1)
60+
61+
# Convert to PIL Image
62+
transform = transforms.ToPILImage()
63+
image = transform(tensor)
64+
65+
# Save image
66+
image.save(output_path)
67+
logger.info(f"Image saved to: {output_path}")
68+
69+
70+
def run_ssr_inference(input_image_path, output_dir="models/experimental/SSR/demo/images/", with_conv=True):
71+
"""Run SSR model inference on input image"""
72+
73+
# Load input image
74+
logger.info(f"Loading image from: {input_image_path}")
75+
x = load_image(input_image_path)
76+
logger.info(f"Input image shape: {x.shape}")
77+
78+
# Create output directory if it doesn't exist
79+
if not os.path.exists(output_dir):
80+
os.makedirs(output_dir)
81+
82+
# Create args
83+
args = Args()
84+
num_cls = 1
85+
86+
# Create reference PyTorch model
87+
if with_conv:
88+
ref_model = SSR(args, num_cls)
89+
else:
90+
ref_model = SSR_wo_conv(args, num_cls)
91+
ref_model.eval()
92+
93+
# Get reference output
94+
logger.info("Running PyTorch reference model...")
95+
with torch.no_grad():
96+
ref_sr, ref_patch_fea3, ref_patch_fea2, ref_patch_fea1 = ref_model(x)
97+
98+
# Save reference output
99+
ref_output_path = os.path.join(output_dir, "reference_output.png")
100+
logger.info("Saving PyTorch reference output...")
101+
save_tensor_as_image(ref_sr, ref_output_path)
102+
103+
# Open TTNN device with larger L1 cache to handle memory requirements
104+
device = ttnn.open_device(device_id=0, l1_small_size=32768)
105+
106+
try:
107+
# Preprocess model parameters
108+
logger.info("Preprocessing model parameters...")
109+
parameters = preprocess_model_parameters(
110+
initialize_model=lambda: ref_model,
111+
custom_preprocessor=create_ssr_preprocessor(device, args, num_cls),
112+
device=device,
113+
)
114+
115+
# Create TTNN model
116+
logger.info("Creating TTNN model...")
117+
if with_conv:
118+
tt_model = TTSSR(
119+
device=device,
120+
parameters=parameters,
121+
args=args,
122+
num_cls=num_cls,
123+
)
124+
else:
125+
tt_model = TTSSR_wo_conv(
126+
device=device,
127+
parameters=parameters,
128+
args=args,
129+
num_cls=num_cls,
130+
)
131+
132+
# Convert input to TTNN tensor
133+
logger.info("Converting input to TTNN tensor...")
134+
tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT)
135+
136+
# Run TTNN model
137+
logger.info("Running TTNN model inference...")
138+
tt_sr, tt_patch_fea3 = tt_model(tt_input)
139+
140+
# Convert back to torch tensors
141+
logger.info("Converting outputs back to torch tensors...")
142+
tt_torch_sr = tt2torch_tensor(tt_sr)
143+
tt_torch_sr = tt_torch_sr.permute(0, 3, 1, 2)
144+
145+
# Save TTNN output image
146+
ttnn_output_path = os.path.join(output_dir, "ttnn_output.png")
147+
logger.info("Saving TTNN super-resolved image...")
148+
save_tensor_as_image(tt_torch_sr, ttnn_output_path)
149+
150+
# Compare outputs (optional - for validation)
151+
sr_pass, sr_pcc_message = comp_pcc(ref_sr, tt_torch_sr, 0.95)
152+
logger.info(f"SR Output PCC: {sr_pcc_message}")
153+
154+
if sr_pass:
155+
logger.info("TTSSR inference completed successfully!")
156+
else:
157+
logger.warning("TTSSR inference completed with quality concerns.")
158+
159+
logger.info(f"Reference output saved to: {ref_output_path}")
160+
logger.info(f"TTNN output saved to: {ttnn_output_path}")
161+
162+
return tt_torch_sr, ref_sr
163+
164+
finally:
165+
ttnn.close_device(device)
166+
167+
168+
if __name__ == "__main__":
169+
import argparse
170+
171+
parser = argparse.ArgumentParser(description="SSR Super-Resolution Inference")
172+
parser.add_argument(
173+
"--input",
174+
type=str,
175+
default="models/experimental/SSR/demo/images/ssr_test_image.jpg",
176+
help="Path to input image",
177+
)
178+
parser.add_argument(
179+
"--output-dir", type=str, default="models/experimental/SSR/demo/images/", help="Directory to save output images"
180+
)
181+
parser.add_argument("--with-conv", action="store_true", default=False, help="Use SSR model with conv layers")
182+
183+
args = parser.parse_args()
184+
185+
run_ssr_inference(args.input, args.output_dir, args.with_conv)

models/experimental/SSR/tests/test_RHAG.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from models.experimental.SSR.tests.test_atten_blocks import create_atten_blocks_preprocessor
1313

1414
# from models.experimental.SSR.tests.test_patch_embed import create_patch_embed_preprocessor
15-
from models.experimental.SSR.tests.test_patch_embed_tile_selection import create_patch_embed_preprocessor_conv
15+
from models.experimental.SSR.tests.test_patch_embed_tile_refinement import create_patch_embed_preprocessor_conv
1616
from models.utility_functions import comp_pcc
1717
from ttnn.model_preprocessing import preprocess_model_parameters
1818

models/experimental/SSR/tests/test_tile_refinement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# Import the reference models (adjust import paths as needed)
77
from models.experimental.SSR.reference.SSR.model.tile_refinement import TileRefinement
8-
from models.experimental.SSR.tests.test_patch_embed_tile_selection import create_patch_embed_preprocessor_conv
8+
from models.experimental.SSR.tests.test_patch_embed_tile_refinement import create_patch_embed_preprocessor_conv
99
from models.experimental.SSR.tests.test_RHAG import create_rhag_preprocessor
1010
from models.experimental.SSR.tests.test_upsample import create_upsample_preprocessor
1111
from models.experimental.SSR.tt.tile_refinement import TTTileRefinement

0 commit comments

Comments
 (0)