|
| 1 | +import os |
| 2 | +import torch |
| 3 | +import ttnn |
| 4 | +from loguru import logger |
| 5 | +from PIL import Image |
| 6 | +import torchvision.transforms as transforms |
| 7 | + |
| 8 | +# Import reference and TTNN models |
| 9 | +from models.experimental.SSR.reference.SSR.model.ssr import SSR, SSR_wo_conv |
| 10 | +from models.experimental.SSR.tt.ssr import TTSSR, TTSSR_wo_conv |
| 11 | + |
| 12 | +from ttnn.model_preprocessing import preprocess_model_parameters |
| 13 | +from models.utility_functions import ( |
| 14 | + tt2torch_tensor, |
| 15 | + comp_pcc, |
| 16 | +) |
| 17 | + |
| 18 | +from models.experimental.SSR.tests.test_ssr import create_ssr_preprocessor |
| 19 | + |
| 20 | + |
| 21 | +class Args: |
| 22 | + """Args class for SSR model""" |
| 23 | + |
| 24 | + def __init__(self): |
| 25 | + self.token_size = 4 |
| 26 | + self.imgsz = 256 |
| 27 | + self.patchsz = 2 |
| 28 | + self.pretrain = False |
| 29 | + self.ckpt = None |
| 30 | + self.dim = 96 |
| 31 | + |
| 32 | + |
| 33 | +def load_image(image_path, target_size=(256, 256)): |
| 34 | + """Load and preprocess image for SSR model""" |
| 35 | + # Load image |
| 36 | + image = Image.open(image_path).convert("RGB") |
| 37 | + |
| 38 | + # Define transforms |
| 39 | + transform = transforms.Compose( |
| 40 | + [ |
| 41 | + transforms.Resize(target_size), |
| 42 | + transforms.ToTensor(), |
| 43 | + ] |
| 44 | + ) |
| 45 | + |
| 46 | + # Apply transforms and add batch dimension |
| 47 | + image_tensor = transform(image).unsqueeze(0) # Shape: (1, 3, 256, 256) |
| 48 | + |
| 49 | + return image_tensor |
| 50 | + |
| 51 | + |
| 52 | +def save_tensor_as_image(tensor, output_path): |
| 53 | + """Save tensor as image""" |
| 54 | + # Remove batch dimension and convert to numpy |
| 55 | + if tensor.dim() == 4: |
| 56 | + tensor = tensor.squeeze(0) # Remove batch dimension |
| 57 | + |
| 58 | + # Clamp values to [0, 1] range |
| 59 | + tensor = torch.clamp(tensor, 0, 1) |
| 60 | + |
| 61 | + # Convert to PIL Image |
| 62 | + transform = transforms.ToPILImage() |
| 63 | + image = transform(tensor) |
| 64 | + |
| 65 | + # Save image |
| 66 | + image.save(output_path) |
| 67 | + logger.info(f"Image saved to: {output_path}") |
| 68 | + |
| 69 | + |
| 70 | +def run_ssr_inference(input_image_path, output_dir="models/experimental/SSR/demo/images/", with_conv=True): |
| 71 | + """Run SSR model inference on input image""" |
| 72 | + |
| 73 | + # Load input image |
| 74 | + logger.info(f"Loading image from: {input_image_path}") |
| 75 | + x = load_image(input_image_path) |
| 76 | + logger.info(f"Input image shape: {x.shape}") |
| 77 | + |
| 78 | + # Create output directory if it doesn't exist |
| 79 | + if not os.path.exists(output_dir): |
| 80 | + os.makedirs(output_dir) |
| 81 | + |
| 82 | + # Create args |
| 83 | + args = Args() |
| 84 | + num_cls = 1 |
| 85 | + |
| 86 | + # Create reference PyTorch model |
| 87 | + if with_conv: |
| 88 | + ref_model = SSR(args, num_cls) |
| 89 | + else: |
| 90 | + ref_model = SSR_wo_conv(args, num_cls) |
| 91 | + ref_model.eval() |
| 92 | + |
| 93 | + # Get reference output |
| 94 | + logger.info("Running PyTorch reference model...") |
| 95 | + with torch.no_grad(): |
| 96 | + ref_sr, ref_patch_fea3, ref_patch_fea2, ref_patch_fea1 = ref_model(x) |
| 97 | + |
| 98 | + # Save reference output |
| 99 | + ref_output_path = os.path.join(output_dir, "reference_output.png") |
| 100 | + logger.info("Saving PyTorch reference output...") |
| 101 | + save_tensor_as_image(ref_sr, ref_output_path) |
| 102 | + |
| 103 | + # Open TTNN device with larger L1 cache to handle memory requirements |
| 104 | + device = ttnn.open_device(device_id=0, l1_small_size=32768) |
| 105 | + |
| 106 | + try: |
| 107 | + # Preprocess model parameters |
| 108 | + logger.info("Preprocessing model parameters...") |
| 109 | + parameters = preprocess_model_parameters( |
| 110 | + initialize_model=lambda: ref_model, |
| 111 | + custom_preprocessor=create_ssr_preprocessor(device, args, num_cls), |
| 112 | + device=device, |
| 113 | + ) |
| 114 | + |
| 115 | + # Create TTNN model |
| 116 | + logger.info("Creating TTNN model...") |
| 117 | + if with_conv: |
| 118 | + tt_model = TTSSR( |
| 119 | + device=device, |
| 120 | + parameters=parameters, |
| 121 | + args=args, |
| 122 | + num_cls=num_cls, |
| 123 | + ) |
| 124 | + else: |
| 125 | + tt_model = TTSSR_wo_conv( |
| 126 | + device=device, |
| 127 | + parameters=parameters, |
| 128 | + args=args, |
| 129 | + num_cls=num_cls, |
| 130 | + ) |
| 131 | + |
| 132 | + # Convert input to TTNN tensor |
| 133 | + logger.info("Converting input to TTNN tensor...") |
| 134 | + tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT) |
| 135 | + |
| 136 | + # Run TTNN model |
| 137 | + logger.info("Running TTNN model inference...") |
| 138 | + tt_sr, tt_patch_fea3 = tt_model(tt_input) |
| 139 | + |
| 140 | + # Convert back to torch tensors |
| 141 | + logger.info("Converting outputs back to torch tensors...") |
| 142 | + tt_torch_sr = tt2torch_tensor(tt_sr) |
| 143 | + tt_torch_sr = tt_torch_sr.permute(0, 3, 1, 2) |
| 144 | + |
| 145 | + # Save TTNN output image |
| 146 | + ttnn_output_path = os.path.join(output_dir, "ttnn_output.png") |
| 147 | + logger.info("Saving TTNN super-resolved image...") |
| 148 | + save_tensor_as_image(tt_torch_sr, ttnn_output_path) |
| 149 | + |
| 150 | + # Compare outputs (optional - for validation) |
| 151 | + sr_pass, sr_pcc_message = comp_pcc(ref_sr, tt_torch_sr, 0.95) |
| 152 | + logger.info(f"SR Output PCC: {sr_pcc_message}") |
| 153 | + |
| 154 | + if sr_pass: |
| 155 | + logger.info("TTSSR inference completed successfully!") |
| 156 | + else: |
| 157 | + logger.warning("TTSSR inference completed with quality concerns.") |
| 158 | + |
| 159 | + logger.info(f"Reference output saved to: {ref_output_path}") |
| 160 | + logger.info(f"TTNN output saved to: {ttnn_output_path}") |
| 161 | + |
| 162 | + return tt_torch_sr, ref_sr |
| 163 | + |
| 164 | + finally: |
| 165 | + ttnn.close_device(device) |
| 166 | + |
| 167 | + |
| 168 | +if __name__ == "__main__": |
| 169 | + import argparse |
| 170 | + |
| 171 | + parser = argparse.ArgumentParser(description="SSR Super-Resolution Inference") |
| 172 | + parser.add_argument( |
| 173 | + "--input", |
| 174 | + type=str, |
| 175 | + default="models/experimental/SSR/demo/images/ssr_test_image.jpg", |
| 176 | + help="Path to input image", |
| 177 | + ) |
| 178 | + parser.add_argument( |
| 179 | + "--output-dir", type=str, default="models/experimental/SSR/demo/images/", help="Directory to save output images" |
| 180 | + ) |
| 181 | + parser.add_argument("--with-conv", action="store_true", default=False, help="Use SSR model with conv layers") |
| 182 | + |
| 183 | + args = parser.parse_args() |
| 184 | + |
| 185 | + run_ssr_inference(args.input, args.output_dir, args.with_conv) |
0 commit comments