diff --git a/models/experimental/SSR/README.md b/models/experimental/SSR/README.md
new file mode 100644
index 000000000000..0fa176956ad4
--- /dev/null
+++ b/models/experimental/SSR/README.md
@@ -0,0 +1,68 @@
+# Transformer-based Selective Super-Resolution for Efficient Image Refinement (SSR)
+
+## Platforms:
+    Wormhole (n150)
+
+## Introduction
+Selective Super-Resolution (SSR) is a transformer-based framework that partitions an image into tiles, uses a multi-scale pyramid to select only object-relevant regions, and applies deep refinement exclusively where it matters, avoiding background over-sharpening. By skipping heavy processing on unimportant tiles, SSR cuts computation by roughly 40% while improving visual fidelity for downstream tasks, achieving large quality gains such as reducing FID on BDD100K from 26.78 to 10.41.
+
+[Link to paper](https://arxiv.org/abs/2312.05803)
+
+**NOTE:** Trained weights are not available at this time. The implementation uses random weights to verify correctness against the reference implementation.
+
+## Model Structure
+
+```
+SSR
+├─ TileSelection
+│  ├─ Patch Embed
+│  ├─ MaskTokenInference
+│  └─ Basic Layer
+│     ├─ PatchMerging
+│     └─ Swin Transformer Block
+│        ├─ Window Attention
+│        └─ MLP
+└─ TileRefinement
+   └─ HAT
+      ├─ Patch Embed
+      ├─ Patch Unembed
+      ├─ RHAG
+      │  └─ Attention Blocks
+      │     ├─ HAB
+      │     │  ├─ Window Attention
+      │     │  ├─ CAB
+      │     │  │  └─ Channel Attention
+      │     │  └─ MLP
+      │     └─ OCAB
+      │        └─ MLP
+      ├─ Patch Embed
+      ├─ Patch Unembed
+      └─ Upsample
+```
+
+## Prerequisites
+- Clone the [tt-metal repository](https://github.com/tenstorrent/tt-metal) for the source code
+- Install [TT-Metalium™ / TT-NN™](https://github.com/tenstorrent/tt-metal/blob/main/INSTALLING.md)
+  - To obtain performance reports through the profiler, build with `./build_metal.sh -p`
+
+## How to Run
+### Performance run (slightly lower PCC, less device time):
+```
+pytest models/experimental/SSR/tests/test_ssr.py -k "performance"
+```
+### Accuracy run (slightly higher PCC, more device time):
+```
+pytest models/experimental/SSR/tests/test_ssr.py -k "accuracy"
+```
+
+
+### Demo
+```
+python models/experimental/SSR/demo/ssr_demo.py --accuracy --depth 6 --num_heads 6
+```
+NOTE: If `--input` is not provided, the demo uses the default image at `models/experimental/SSR/demo/images/ssr_test_image.jpg`. Make sure this file exists, or the demo will fail.
+
+## Details
+- The entry point to `SSR` is located at: `models/experimental/SSR/tt/ssr.py`.
+- Batch Size: `1` (single device).
+- Supported Input Resolution: `(256, 256)` (Height, Width).
diff --git a/models/experimental/SSR/demo/ssr_demo.py b/models/experimental/SSR/demo/ssr_demo.py
new file mode 100644
index 000000000000..586df2cc0218
--- /dev/null
+++ b/models/experimental/SSR/demo/ssr_demo.py
@@ -0,0 +1,217 @@
+# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+# SPDX-License-Identifier: Apache-2.0 + +import os +import torch +import ttnn +import argparse +from loguru import logger +from PIL import Image +import torchvision.transforms as transforms + +from models.experimental.SSR.reference.SSR.model.ssr import SSR, SSR_wo_conv +from models.experimental.SSR.tt.ssr import TTSSR, TTSSR_wo_conv + +from models.experimental.SSR.reference.SSR.model.net_blocks import window_reverse +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import tt2torch_tensor +from models.utility_functions import ( + comp_pcc, +) + +from models.experimental.SSR.tests.test_ssr import create_ssr_preprocessor + + +class Args: + """Args class for SSR model""" + + def __init__(self): + self.token_size = 4 + self.imgsz = 256 + self.patchsz = 2 + self.pretrain = False + self.ckpt = None + self.dim = 96 + + +def load_image(image_path, target_size=(256, 256)): + """Load and preprocess image for SSR model""" + # Load image + image = Image.open(image_path).convert("RGB") + + # Define transforms + transform = transforms.Compose( + [ + transforms.Resize(target_size), + transforms.ToTensor(), + ] + ) + + # Apply transforms and add batch dimension + image_tensor = transform(image).unsqueeze(0) # Shape: (1, 3, 256, 256) + + return image_tensor + + +def save_tensor_as_image(tensor, output_path): + """Save tensor as image""" + # Remove batch dimension and convert to numpy + if tensor.dim() == 4: + tensor = tensor.squeeze(0) # Remove batch dimension + + # Convert BFloat16 to Float32 if needed + if tensor.dtype == torch.bfloat16: + tensor = tensor.to(torch.float32) + + # Clamp values to [0, 1] range + tensor = torch.clamp(tensor, 0, 1) + + # Convert to PIL Image + transform = transforms.ToPILImage() + image = transform(tensor) + + # Save image + image.save(output_path) + logger.info(f"Image saved to: {output_path}") + + +def run_ssr_inference( + input_image_path, + output_dir="models/experimental/SSR/demo/images/", + with_conv=False, + accuracy_mode=False, + depth="1", + num_heads="1", +): + """Run SSR model inference on input image""" + + # Load input image + logger.info(f"Loading image from: {input_image_path}") + x = load_image(input_image_path) + logger.info(f"Input image shape: {x.shape}") + + torch.manual_seed(0) + + # Create output directory if it doesn't exist + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + depth_map = {"1": [1], "6": [6, 6, 6, 6, 6, 6]} + num_heads_map = {"1": [1], "6": [6, 6, 6, 6, 6, 6]} + actual_depth = depth_map[depth] + actual_num_heads = num_heads_map[num_heads] + + # Create args + args = Args() + num_cls = 1 + + # Create reference PyTorch model + if with_conv: + ref_model = SSR(args, num_cls, depth=actual_depth, num_heads=actual_num_heads) + else: + ref_model = SSR_wo_conv(args, num_cls, depth=actual_depth, num_heads=actual_num_heads) + ref_model.eval() + + # Get reference output + logger.info("Running PyTorch reference model...") + with torch.no_grad(): + ref_sr, ref_patch_fea3, ref_patch_fea2, ref_patch_fea1 = ref_model(x) + + # Save reference output + ref_output_path = os.path.join(output_dir, "reference_output.png") + logger.info("Saving PyTorch reference output...") + save_tensor_as_image(ref_sr, ref_output_path) + + # Open TTNN device with larger L1 cache to handle memory requirements + device = ttnn.open_device(device_id=0, l1_small_size=32768) + + try: + # Preprocess model parameters + logger.info("Preprocessing model parameters...") + parameters = preprocess_model_parameters( + 
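+            # Map the reference model's torch weights into ttnn tensors on the device.
+            # create_ssr_preprocessor (imported from the SSR tests) is assumed to handle
+            # any module-specific layout/dtype conversion needed by the TT modules.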
initialize_model=lambda: ref_model, + custom_preprocessor=create_ssr_preprocessor(device, args, num_cls, actual_depth), + device=device, + ) + + # Create TTNN model + logger.info("Creating TTNN model...") + if with_conv: + tt_model = TTSSR( + device=device, + parameters=parameters, + args=args, + num_cls=num_cls, + depth=actual_depth, + num_heads=actual_num_heads, + ) + else: + tt_model = TTSSR_wo_conv( + device=device, + parameters=parameters, + args=args, + num_cls=num_cls, + depth=actual_depth, + num_heads=actual_num_heads, + ) + + # Convert input to TTNN tensor + logger.info("Converting input to TTNN tensor...") + tt_input = ttnn.from_torch( + x, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16 if accuracy_mode else ttnn.bfloat8_b + ) + + # Run TTNN model + logger.info("Running TTNN model inference...") + tt_sr, tt_patch_fea3 = tt_model(tt_input) + # Convert back to torch tensors + tt_torch_sr = tt2torch_tensor(tt_sr) + tt_torch_patch_fea3 = tt2torch_tensor(tt_patch_fea3) + tt_torch_sr = tt_torch_sr.permute(0, 3, 1, 2) + if not with_conv: + _, _, H, W = x.shape + tt_torch_sr = window_reverse(tt_torch_sr.permute(0, 2, 3, 1), window_size=H, H=H * 4, W=W * 4) + tt_torch_sr = tt_torch_sr.permute(0, 3, 1, 2) + + # Save TTNN output image + ttnn_output_path = os.path.join(output_dir, "ttnn_output.png") + logger.info("Saving TTNN super-resolved image...") + save_tensor_as_image(tt_torch_sr, ttnn_output_path) + + # Compare outputs (optional - for validation) + sr_pass, sr_pcc_message = comp_pcc(ref_sr, tt_torch_sr, 0.90) + logger.info(f"SR Output PCC: {sr_pcc_message}") + + if sr_pass: + logger.info("TTSSR inference completed successfully!") + else: + logger.warning("TTSSR inference completed with quality concerns.") + + logger.info(f"Reference output saved to: {ref_output_path}") + logger.info(f"TTNN output saved to: {ttnn_output_path}") + + return tt_sr, ref_sr + + finally: + ttnn.close_device(device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="SSR Super-Resolution Inference") + parser.add_argument( + "--input", + type=str, + default="models/experimental/SSR/demo/images/ssr_test_image.jpg", + help="Path to input image", + ) + parser.add_argument( + "--output-dir", type=str, default="models/experimental/SSR/demo/images/", help="Directory to save output images" + ) + parser.add_argument("--with-conv", action="store_true", default=False, help="Use SSR model with conv layers") + parser.add_argument("--accuracy", action="store_true", default=False, help="Set flag to run in bfloat16 precision") + parser.add_argument("--depth", choices=["1", "6"], default="6", help="SSR depth configuration") + parser.add_argument("--num_heads", choices=["1", "6"], default="6", help="SSR num heads configuration") + + args = parser.parse_args() + + run_ssr_inference(args.input, args.output_dir, args.with_conv, args.accuracy, args.depth, args.num_heads) diff --git a/models/experimental/SSR/reference/SSR/model/net_blocks.py b/models/experimental/SSR/reference/SSR/model/net_blocks.py new file mode 100644 index 000000000000..7909a332d94e --- /dev/null +++ b/models/experimental/SSR/reference/SSR/model/net_blocks.py @@ -0,0 +1,735 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + + +# Reference: +# https://github.com/destiny301/SSR + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +import torch.nn.functional as F +from typing import Tuple +import numpy as np + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_partition_padding(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. 
+ It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # 
qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows, pad_hw = window_partition_padding( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + 
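+        # Shifted-window attention: odd-indexed blocks in a BasicLayer use
+        # shift_size = window_size // 2, so the feature map is cyclically rolled before
+        # window partitioning and the precomputed attn_mask prevents tokens from attending
+        # across the wrapped-around borders.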
# cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows, pad_hw = window_partition_padding(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_unpartition(attn_windows, self.window_size, pad_hw, (H, W)) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
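+        # 2x downsample: gather the four interleaved 2x2 sub-grids, concatenate them along
+        # channels (C -> 4C), then apply LayerNorm and a linear reduction to 2C, halving the
+        # spatial resolution while doubling the channel width.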
+ + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class PatchExpand(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity() + self.norm = norm_layer(dim // dim_scale) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, "b h w (p1 p2 c)-> b (h p1) (w p2) c", p1=2, p2=2, c=C // 4) + x = x.view(B, -1, C // 4) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # Setting for convenient params extraction in test + self.window_size = window_size + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class MutualAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.scale = qk_scale or head_dim**-0.5 + + self.rgb_q = nn.Linear(dim, dim, bias=qkv_bias) + self.rgb_k = nn.Linear(dim, dim, bias=qkv_bias) + self.rgb_v = nn.Linear(dim, dim, bias=qkv_bias) + self.rgb_proj = nn.Linear(dim, dim) + + self.depth_q = nn.Linear(dim, dim, bias=qkv_bias) + self.depth_k = nn.Linear(dim, dim, bias=qkv_bias) + self.depth_v = nn.Linear(dim, dim, bias=qkv_bias) + self.depth_proj = nn.Linear(dim, dim) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, rgb_fea, depth_fea): + B, N, C = rgb_fea.shape + + rgb_q = self.rgb_q(rgb_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + rgb_k = self.rgb_k(rgb_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + rgb_v = self.rgb_v(rgb_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + # q [B, nhead, N, C//nhead] + + depth_q = self.depth_q(depth_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + depth_k = self.depth_k(depth_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + depth_v = self.depth_v(depth_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + # rgb branch + rgb_attn = (rgb_q @ depth_k.transpose(-2, -1)) * self.scale + rgb_attn = rgb_attn.softmax(dim=-1) + rgb_attn = self.attn_drop(rgb_attn) + + rgb_fea = (rgb_attn @ depth_v).transpose(1, 2).reshape(B, N, C) + rgb_fea = self.rgb_proj(rgb_fea) + rgb_fea = self.proj_drop(rgb_fea) + + # depth branch + depth_attn = (depth_q @ rgb_k.transpose(-2, -1)) * self.scale + depth_attn = depth_attn.softmax(dim=-1) + depth_attn = self.attn_drop(depth_attn) + + 
depth_fea = (depth_attn @ rgb_v).transpose(1, 2).reshape(B, N, C) + depth_fea = self.depth_proj(depth_fea) + depth_fea = self.proj_drop(depth_fea) + + return rgb_fea, depth_fea + + +def get_sinusoid_encoding(n_position, d_hid): + """Sinusoid position encoding table""" + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +class BasicLayer_up(nn.Module): + """A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + upsample=None, + use_checkpoint=False, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if upsample is not None: + self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) + else: + self.upsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.upsample is not None: + x = self.upsample(x) + return x + + +class PatchEmbed(nn.Module): + r"""Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/models/experimental/SSR/reference/SSR/model/ssr.py b/models/experimental/SSR/reference/SSR/model/ssr.py new file mode 100644 index 000000000000..586b266a70b7 --- /dev/null +++ b/models/experimental/SSR/reference/SSR/model/ssr.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + + +# Reference: +# https://github.com/destiny301/SSR + +import torch.nn as nn +import torch + +from models.experimental.SSR.reference.SSR.model.tile_refinement import TileRefinement, Upsample +from models.experimental.SSR.reference.SSR.model.tile_selection import TileSelection +from models.experimental.SSR.reference.SSR.model.net_blocks import window_partition, window_reverse + + +class SSR(nn.Module): + """ + feed pos tiles to TR Module, neg tiles to conv layers, then reconstruct them together + """ + + def __init__(self, args, num_cls, depth, num_heads) -> None: + super().__init__() + self.select_model = TileSelection(args, num_cls) + self.sr_model = TileRefinement( + upscale=4, + img_size=64, + window_size=16, + img_range=1.0, + depths=depth, + embed_dim=180, + num_heads=num_heads, + mlp_ratio=2, + upsampler="pixelshuffle", + ) + + if args.pretrain: + self.sr_model.load_state_dict(torch.load(args.ckpt)["params_ema"]) + print("----------loaded TR pretrained model-----------------") + + # image reconstruction + self.conv_first = nn.Conv2d(3, 180, 3, 1, 1) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(180, 64, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(4, 64) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + + def forward(self, x): + B, C, H, W = x.shape + patch_fea3, patch_fea2, patch_fea1 = self.select_model(x) + pi_prime = patch_fea3.view(-1) + pi_prime = pi_prime > torch.quantile(pi_prime, 0.75) + # pi_prime = gumbel_sigmoid(patch_fea3, hard=True) # top25% or gumbel_sigmoid hard selection + patch_x = window_partition(x.permute(0, 2, 3, 1), window_size=(H // 4)) # B*4*4, H/4, W/4, 3 + patch_x = patch_x.permute(0, 3, 1, 2) # B*4*4, 3, H/4, W/4 + pi_prime = pi_prime.view(-1) + + # feature extraction + lr_fea = torch.zeros((0, 180, 64, 64)).to(x.device) + for i in range(B * 16): + if pi_prime[i] == 1: + posX, fea = 
self.sr_model(patch_x[i].unsqueeze(0)) + lr_fea = torch.cat([lr_fea, fea], dim=0) + else: + fea = self.conv_first(patch_x[i].unsqueeze(0)) + lr_fea = torch.cat([lr_fea, fea], dim=0) + lr_fea = window_reverse(lr_fea.permute(0, 2, 3, 1), window_size=H // 4, H=H, W=W).permute(0, 3, 1, 2) + + # image reconstruction + sr_fea = self.upsample(self.conv_before_upsample(lr_fea)) + sr = self.conv_last(sr_fea) + + return sr, patch_fea3, patch_fea2, patch_fea1 + + +class SSR_wo_conv(nn.Module): + """ + simply feed pos tiles to TR module, neg tiles to upsample layer + """ + + def __init__(self, args, num_cls, depth, num_heads) -> None: + super().__init__() + self.select_model = TileSelection(args, num_cls) + self.sr_model = TileRefinement( + upscale=4, + img_size=64, + window_size=16, + img_range=1.0, + depths=depth, + embed_dim=180, + num_heads=num_heads, + mlp_ratio=2, + upsampler="pixelshuffle", + ) + + if args.pretrain: + self.sr_model.load_state_dict(torch.load(args.ckpt)["params_ema"]) + print("----------loaded TR pretrained model-----------------") + + self.upsample = nn.Upsample(scale_factor=4, mode="bicubic") + + def forward(self, x): + B, C, H, W = x.shape + patch_fea3, patch_fea2, patch_fea1 = self.select_model(x) + pi_prime = patch_fea3.view(-1) + pi_prime = pi_prime > torch.quantile(pi_prime, 0.75) + # pi_prime = gumbel_sigmoid(patch_fea3, hard=True) # top25% or gumbel_sigmoid hard selection + patch_x = window_partition(x.permute(0, 2, 3, 1), window_size=(H // 4)) # B*4*4, H/4, W/4, 3 + patch_x = patch_x.permute(0, 3, 1, 2) # B*4*4, 3, H/4, W/4 + pi_prime = pi_prime.view(-1) + + sr = torch.zeros((0, C, H, W)).to(x.device) + + for i in range(B * 16): + if pi_prime[i] == 1: + posX, _ = self.sr_model(patch_x[i].unsqueeze(0)) + sr = torch.cat([sr, posX], dim=0) + else: + negX = self.upsample(patch_x[i].unsqueeze(0)) + sr = torch.cat([sr, negX], dim=0) + + sr = window_reverse(sr.permute(0, 2, 3, 1), window_size=H, H=H * 4, W=W * 4) + + return sr.permute(0, 3, 1, 2), patch_fea3, patch_fea2, patch_fea1 diff --git a/models/experimental/SSR/reference/SSR/model/tile_refinement.py b/models/experimental/SSR/reference/SSR/model/tile_refinement.py new file mode 100644 index 000000000000..341541eda3e2 --- /dev/null +++ b/models/experimental/SSR/reference/SSR/model/tile_refinement.py @@ -0,0 +1,1122 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + + +# Reference: +# https://github.com/destiny301/SSR + + +import math +import torch +import torch.nn as nn + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class ChannelAttention(nn.Module): + """Channel attention used in RCAN. + Args: + num_feat (int): Channel number of intermediate features. + squeeze_factor (int): Channel squeeze factor. Default: 16. + """ + + def __init__(self, num_feat, squeeze_factor=16): + super(ChannelAttention, self).__init__() + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), + nn.ReLU(inplace=True), + nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), + nn.Sigmoid(), + ) + + def forward(self, x): + y = self.attention(x) + return x * y + + +class CAB(nn.Module): + def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30): + super(CAB, self).__init__() + + self.cab = nn.Sequential( + nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1), + nn.GELU(), + nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1), + ChannelAttention(num_feat, squeeze_factor), + ) + + def forward(self, x): + return self.cab(x) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (b, h, w, c) + window_size (int): window size + + Returns: + windows: (num_windows*b, window_size, window_size, c) + """ + b, h, w, c = x.shape + x = x.view(b, h // window_size, window_size, w // window_size, window_size, c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) + return windows + + +def window_reverse(windows, window_size, h, w): + """ + Args: + windows: (num_windows*b, window_size, window_size, c) + window_size (int): Window size + h (int): Height of image + w (int): Width of image + + Returns: + x: (b, h, w, c) + """ + b = int(windows.shape[0] / (h * w / window_size / window_size)) + x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) + return x + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, rpi, mask=None): + """ + Args: + x: input features with shape of (num_windows*b, n, c) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + b_, n, c = x.shape + qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b_, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class HAB(nn.Module): + r"""Hybrid Attention Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.conv_scale = conv_scale + self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size, rpi_sa, attn_mask): + h, w = x_size + b, _, c = x.shape + # assert seq_len == h * w, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + # Conv_X + conv_x = self.conv_block(x.permute(0, 3, 1, 2)) + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = attn_mask + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c + x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) + shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c + + # reverse cyclic shift + if self.shift_size > 0: + attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attn_x = shifted_x + attn_x = attn_x.view(b, h * w, c) + + # FFN + x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: b, h*w, c + """ + h, w = self.input_resolution + b, seq_len, c = x.shape + assert seq_len == h * w, "input feature has wrong size" + assert h % 2 == 0 and w % 2 == 0, f"x size ({h}*{w}) are not even." + + x = x.view(b, h, w, c) + + x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c + x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c + x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c + x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c + x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c + x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class OCAB(nn.Module): + # overlapping cross-attention block + + def __init__( + self, + dim, + input_resolution, + window_size, + overlap_ratio, + num_heads, + qkv_bias=True, + qk_scale=None, + mlp_ratio=2, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.overlap_win_size = int(window_size * overlap_ratio) + window_size + + self.norm1 = norm_layer(dim) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.unfold = nn.Unfold( + kernel_size=(self.overlap_win_size, self.overlap_win_size), + stride=window_size, + padding=(self.overlap_win_size - window_size) // 2, + ) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), num_heads + ) + ) # 2*Wh-1 * 2*Ww-1, nH + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + self.proj = nn.Linear(dim, dim) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU) + + def forward(self, x, x_size, rpi): + h, w = x_size + b, _, c = x.shape + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2) # 3, b, c, h, w + q = qkv[0].permute(0, 2, 3, 1) # b, h, w, c + kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w + + # partition windows + q_windows = window_partition(q, self.window_size) # nw*b, window_size, window_size, c + q_windows = q_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + kv_windows = self.unfold(kv) # b, c*w*w, nw + kv_windows = rearrange( + kv_windows, + "b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch", + nc=2, + ch=c, + owh=self.overlap_win_size, + oww=self.overlap_win_size, + ).contiguous() # 2, nw*b, ow*ow, c + k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c + + b_, nq, _ = q_windows.shape + _, n, _ = k_windows.shape + d = self.dim // self.num_heads + q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, nq, d + k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d + v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view( 
+ self.window_size * self.window_size, self.overlap_win_size * self.overlap_win_size, -1 + ) # ws*ws, wse*wse, nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, ws*ws, wse*wse + attn = attn + relative_position_bias.unsqueeze(0) + + attn = self.softmax(attn) + attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim) + x = window_reverse(attn_windows, self.window_size, h, w) # b h w c + x = x.view(b, h * w, self.dim) + + x = self.proj(x) + shortcut + + x = x + self.mlp(self.norm2(x)) + return x + + +class AttenBlocks(nn.Module): + """A series of attention blocks for one RHAG. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + compress_ratio, + squeeze_factor, + conv_scale, + overlap_ratio, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + HAB( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + compress_ratio=compress_ratio, + squeeze_factor=squeeze_factor, + conv_scale=conv_scale, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # OCAB + self.overlap_attn = OCAB( + dim=dim, + input_resolution=input_resolution, + window_size=window_size, + overlap_ratio=overlap_ratio, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size, params): + for blk in self.blocks: + x = blk(x, x_size, params["rpi_sa"], params["attn_mask"]) + + x = self.overlap_attn(x, x_size, params["rpi_oca"]) + + if self.downsample is not None: + x = self.downsample(x) + return x + + +class RHAG(nn.Module): + """Residual Hybrid Attention Group (RHAG). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. 
+ depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + compress_ratio, + squeeze_factor, + conv_scale, + overlap_ratio, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + img_size=224, + patch_size=4, + resi_connection="1conv", + ): + super(RHAG, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = AttenBlocks( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + compress_ratio=compress_ratio, + squeeze_factor=squeeze_factor, + conv_scale=conv_scale, + overlap_ratio=overlap_ratio, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + ) + + if resi_connection == "1conv": + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == "identity": + self.conv = nn.Identity() + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None + ) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None + ) + + def forward(self, x, x_size, params): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x + + +class PatchEmbed(nn.Module): + r"""Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # b Ph*Pw c + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchUnEmbed(nn.Module): + r"""Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c + return x + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f"scale {scale} is not supported. " "Supported scales: 2^n and 3.") + super(Upsample, self).__init__(*m) + + +# @ARCH_REGISTRY.register() +class HAT(nn.Module): + r"""Hybrid Attention Transformer + A PyTorch implementation of : `Activating More Pixels in Image Super-Resolution Transformer`. + Some codes are based on SwinIR. + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. 
Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__( + self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=(6, 6, 6, 6), + num_heads=(6, 6, 6, 6), + window_size=7, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=0.5, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + upscale=2, + img_range=1.0, + upsampler="", + resi_connection="1conv", + **kwargs, + ): + super(HAT, self).__init__() + + self.window_size = window_size + self.shift_size = window_size // 2 + self.overlap_ratio = overlap_ratio + + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + + # relative position index + relative_position_index_SA = self.calculate_rpi_sa() + relative_position_index_OCA = self.calculate_rpi_oca() + self.register_buffer("relative_position_index_SA", relative_position_index_SA) + self.register_buffer("relative_position_index_OCA", relative_position_index_OCA) + + # ------------------------- 1, shallow feature extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, deep feature extraction ------------------------- # + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Hybrid Attention Groups (RHAG) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RHAG( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + 
window_size=window_size, + compress_ratio=compress_ratio, + squeeze_factor=squeeze_factor, + conv_scale=conv_scale, + overlap_ratio=overlap_ratio, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection, + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == "1conv": + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == "identity": + self.conv_after_body = nn.Identity() + + # ------------------------- 3, high quality image reconstruction ------------------------- # + if self.upsampler == "pixelshuffle": + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) + ) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def calculate_rpi_sa(self): + # calculate relative position index for SA + coords_h = torch.arange(self.window_size) + coords_w = torch.arange(self.window_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size - 1 + relative_coords[:, :, 0] *= 2 * self.window_size - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + def calculate_rpi_oca(self): + # calculate relative position index for OCA + window_size_ori = self.window_size + window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size) + + coords_h = torch.arange(window_size_ori) + coords_w = torch.arange(window_size_ori) + coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, ws, ws + coords_ori_flatten = torch.flatten(coords_ori, 1) # 2, ws*ws + + coords_h = torch.arange(window_size_ext) + coords_w = torch.arange(window_size_ext) + coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, wse, wse + coords_ext_flatten = torch.flatten(coords_ext, 1) # 2, wse*wse + + relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None] # 2, ws*ws, wse*wse + + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # ws*ws, wse*wse, 2 + relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1 + + relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1 + relative_position_index = relative_coords.sum(-1) + return relative_position_index + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + h, w = x_size + img_mask 
= torch.zeros((1, h, w, 1)) # 1 h w 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + + # Calculate attention mask and relative position index in advance to speed up inference. + # The original code is very time-cosuming for large window size. + attn_mask = self.calculate_mask(x_size).to(x.device) + params = { + "attn_mask": attn_mask, + "rpi_sa": self.relative_position_index_SA, + "rpi_oca": self.relative_position_index_OCA, + } + + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size, params) + x = self.norm(x) # b seq_len c + # return x + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == "pixelshuffle": + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + + x = x / self.img_range + self.mean + + return x + + +class TileRefinement(HAT): + """ + Tile refinement Module + + output feature and final upsampled image + """ + + def __init__( + self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=(6, 6, 6, 6), + num_heads=(6, 6, 6, 6), + window_size=7, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=0.5, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0, + attn_drop_rate=0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + upscale=2, + img_range=1, + upsampler="", + resi_connection="1conv", + **kwargs, + ): + super().__init__( + img_size, + patch_size, + in_chans, + embed_dim, + depths, + num_heads, + window_size, + compress_ratio, + squeeze_factor, + conv_scale, + overlap_ratio, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + norm_layer, + ape, + patch_norm, + use_checkpoint, + upscale, + img_range, + upsampler, + resi_connection, + **kwargs, + ) + + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + def forward(self, x): + B = x.shape[0] + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == "pixelshuffle": + # for classical SR + x = self.conv_first(x) + fea = self.forward_features(x) # 1, C, 64, 64 + x = self.conv_after_body(fea) + x + x = self.conv_before_upsample(x) + x = self.upsample(x) + x = self.conv_last(x) + # return x, fea + + x = x / self.img_range + self.mean + + return x, fea diff --git 
a/models/experimental/SSR/reference/SSR/model/tile_selection.py b/models/experimental/SSR/reference/SSR/model/tile_selection.py new file mode 100644 index 000000000000..31c5d50e0ff3 --- /dev/null +++ b/models/experimental/SSR/reference/SSR/model/tile_selection.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + + +# Reference: +# https://github.com/destiny301/SSR + +from models.experimental.SSR.reference.SSR.model.net_blocks import BasicLayer, PatchEmbed, Mlp, PatchMerging + +import torch, math +import torch.nn as nn + + +class mask_token_inference(nn.Module): + r"""cross-attention between classfification token and image representation""" + + def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0): + super().__init__() + + # self.norm = nn.LayerNorm(dim, dtype=torch.bfloat16) + self.norm = nn.LayerNorm(dim) + self.num_heads = num_heads + head_dim = dim // num_heads + + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sigmoid = nn.Sigmoid() + + def forward(self, fea): + B, N, C = fea.shape + x = self.norm(fea) + T_s, F_s = x[:, 0, :].unsqueeze(1), x[:, 1:, :] + # T_s [B, 1, c] F_s [B, h*w, c] + + q = self.q(F_s).reshape(B, N - 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.k(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + attn = self.sigmoid(attn) + attn = self.attn_drop(attn) + + infer_fea = (attn @ v).transpose(1, 2).reshape(B, N - 1, C) + infer_fea = self.proj(infer_fea) + infer_fea = self.proj_drop(infer_fea) + + infer_fea = infer_fea + fea[:, 1:, :] + return infer_fea + + +class TileSelection(nn.Module): + r"""Tile Selection Module + Split image into non-overlapping tiles, 4*4/8*4..., classify each tile + + Args: + args. + num_cls: number of output classes (BCELoss-->1) + """ + + def __init__(self, args, num_cls) -> None: + super().__init__() + self.token_size = args.token_size # the size of the final encoded representations, i.e. 
number of tiles (4*4) + self.patch_embed = PatchEmbed(img_size=args.imgsz, patch_size=args.patchsz, in_chans=3) + self.num_layers = int(math.log2((args.imgsz // args.patchsz) // args.token_size)) + + # encoder + patches_resolution = self.patch_embed.patches_resolution + depths = [2, 2, 2, 2, 2, 2] + dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))] # stochastic depth decay rule + + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(args.dim * 2**i_layer), + input_resolution=(patches_resolution[0] // (2**i_layer), patches_resolution[1] // (2**i_layer)), + depth=2, + num_heads=3, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=nn.LayerNorm, + downsample=PatchMerging if (i_layer < self.num_layers) else None, + use_checkpoint=False, + ) + self.layers.append(layer) + + self.norm3 = nn.LayerNorm(args.dim * 2 ** (self.num_layers)) + self.norm2 = nn.LayerNorm(args.dim * 2 ** (self.num_layers - 1)) + self.norm1 = nn.LayerNorm(args.dim * 2 ** (self.num_layers - 2)) + + # classifier + self.mask_token = nn.Embedding(1, 96 * (2**self.num_layers)) + self.fea_mlp3 = Mlp( + in_features=96 * (2**self.num_layers), + hidden_features=96 * (2**self.num_layers), + out_features=96 * (2**self.num_layers), + ) + + self.mask_pre3 = mask_token_inference(dim=96 * (2**self.num_layers), num_heads=1) + self.mlp_norm3 = nn.LayerNorm(96 * (2**self.num_layers)) + self.mlp3 = Mlp(in_features=96 * (2**self.num_layers), hidden_features=96, out_features=96) + self.linear3 = nn.Linear(96, num_cls) + + self.fea_mlp2 = Mlp( + in_features=96 * (2 ** (self.num_layers - 1)), + hidden_features=96 * (2**self.num_layers), + out_features=96 * (2**self.num_layers), + ) + self.mask_pre2 = mask_token_inference(dim=96 * (2**self.num_layers), num_heads=1) + self.mlp_norm2 = nn.LayerNorm(96 * (2**self.num_layers)) + self.mlp2 = Mlp(in_features=96 * (2**self.num_layers), hidden_features=96, out_features=96) + self.linear2 = nn.Linear(96, num_cls) + + self.fea_mlp1 = Mlp( + in_features=96 * (2 ** (self.num_layers - 2)), + hidden_features=96 * (2**self.num_layers), + out_features=96 * (2**self.num_layers), + ) + self.mask_pre1 = mask_token_inference(dim=96 * (2**self.num_layers), num_heads=1) + self.mlp_norm1 = nn.LayerNorm(96 * (2**self.num_layers)) + self.mlp1 = Mlp(in_features=96 * (2**self.num_layers), hidden_features=96, out_features=96) + self.linear1 = nn.Linear(96, num_cls) + + def forward(self, x): + B, C, H, W = x.shape + x = self.patch_embed(x) + + # encoder + x_downsample = [] + + for layer in self.layers: + x_downsample.append(x) + x = layer(x) + x3 = self.norm3(x) + x2 = self.norm2(x_downsample[-1]) + x1 = self.norm1(x_downsample[-2]) + + # decoder + mask_tokens = self.mask_token.weight + mask_tokens = mask_tokens.unsqueeze(0).expand(B, -1, -1) + + # predict 4*4 representations, for 64*64 tile + fea_3 = torch.cat((mask_tokens, self.fea_mlp3(x3)), dim=1) + + mask_tokens = fea_3[:, 0, :].unsqueeze(1) + mask_3 = self.mask_pre3(fea_3) # [B, 16, 96*32] + mask_3 = self.mlp3(self.mlp_norm3(mask_3)) + mask_3 = self.linear3(mask_3) + B, N, C = mask_3.shape + mask_3 = mask_3.transpose(1, 2).reshape(B, C, self.token_size, self.token_size) + + # predict 8*8 representations, for 32*32 tile + fea_2 = torch.cat((mask_tokens, self.fea_mlp2(x2)), dim=1) + + mask_tokens = fea_2[:, 0, :].unsqueeze(1) + mask_2 = self.mask_pre2(fea_2) + mask_2 = 
self.mlp2(self.mlp_norm2(mask_2)) + mask_2 = self.linear2(mask_2) + mask_2 = mask_2.transpose(1, 2).reshape(B, C, self.token_size * 2, self.token_size * 2) + + # predict 16*16 representations, for 16*16 tile + fea_1 = torch.cat((mask_tokens, self.fea_mlp1(x1)), dim=1) + + mask_tokens = fea_1[:, 0, :].unsqueeze(1) + mask_1 = self.mask_pre1(fea_1) + mask_1 = self.mlp1(self.mlp_norm1(mask_1)) + mask_1 = self.linear1(mask_1) + mask_1 = mask_1.transpose(1, 2).reshape(B, C, self.token_size * 4, self.token_size * 4) + + return mask_3, mask_2, mask_1 diff --git a/models/experimental/SSR/tests/common/test_mlp.py b/models/experimental/SSR/tests/common/test_mlp.py new file mode 100644 index 000000000000..c0c74fb3c68e --- /dev/null +++ b/models/experimental/SSR/tests/common/test_mlp.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest + +import ttnn + +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.net_blocks import Mlp +from models.experimental.SSR.tt.common import TTMlp + +from ttnn.model_preprocessing import preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight +from models.utility_functions import ( + tt2torch_tensor, +) +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_mlp_preprocessor(device, weight_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + parameters = {} + if hasattr(torch_model, "fc1") and hasattr(torch_model, "fc2"): # MLP model + parameters["fc1"] = {} + parameters["fc2"] = {} + + # Preprocess fc1 layer parameters + parameters["fc1"]["weight"] = preprocess_linear_weight(torch_model.fc1.weight, dtype=weight_dtype) + parameters["fc1"]["bias"] = preprocess_linear_bias(torch_model.fc1.bias, dtype=weight_dtype) + + # Preprocess fc2 layer parameters + parameters["fc2"]["weight"] = preprocess_linear_weight(torch_model.fc2.weight, dtype=weight_dtype) + parameters["fc2"]["bias"] = preprocess_linear_bias(torch_model.fc2.bias, dtype=weight_dtype) + + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize( + "in_features, hidden_features, out_features, input_shape", + ( + (3072, 3072, 3072, (3, 16, 3072)), # TTTileSelection -> fea_mlp3 + (3072, 96, 96, (3, 16, 3072)), # TTTileSelection -> mlp3 + (96, 384, 96, (3, 16384, 96)), # TTSwinTransformerBlock[0], TTSwinTransformerBlock[1] -> mlp + (192, 768, 192, (3, 4096, 192)), # TTSwinTransformerBlock[2], TTSwinTransformerBlock[3] -> mlp + (384, 1536, 384, (3, 1024, 384)), # TTSwinTransformerBlock[4], TTSwinTransformerBlock[5] -> mlp + (768, 3072, 768, (3, 256, 768)), # TTSwinTransformerBlock[6], TTSwinTransformerBlock[7] -> mlp + (1536, 6144, 1536, (3, 64, 1536)), # TTSwinTransformerBlock[6], TTSwinTransformerBlock[7] -> mlp + (180, 360, None, (1, 4096, 180)), # TR + ), +) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_mlp(device, in_features, hidden_features, out_features, input_shape, input_dtype, weight_dtype): + x = torch.randn(input_shape) + + ref_layer = Mlp( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + ) + + ref_output = ref_layer(x) + + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_layer, + custom_preprocessor=create_mlp_preprocessor(device, weight_dtype), + device=device, + ) + + tt_layer = TTMlp( + device, + in_features=in_features, + 
hidden_features=hidden_features, + out_features=out_features, + parameters=parameters, + dtype=input_dtype, + ) + tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + tt_input = ttnn.to_memory_config(tt_input, memory_config=ttnn.L1_MEMORY_CONFIG) + tt_output = tt_layer(tt_input) + tt_torch_output = tt2torch_tensor(tt_output) + + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.99) + + logger.info(f"PCC: {pcc_message}") + + if does_pass: + logger.info("SSR MLP Passed!") + else: + logger.warning("SSR MLP Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/test_ssr.py b/models/experimental/SSR/tests/test_ssr.py new file mode 100644 index 000000000000..c074b26df1b8 --- /dev/null +++ b/models/experimental/SSR/tests/test_ssr.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from loguru import logger + +# Import reference and TTNN models +from models.experimental.SSR.reference.SSR.model.ssr import SSR, SSR_wo_conv +from models.experimental.SSR.tt.ssr import TTSSR, TTSSR_wo_conv +from models.experimental.SSR.tests.tile_refinement.test_upsample import create_upsample_preprocessor +from models.experimental.SSR.tests.tile_selection.test_tile_selection import create_tile_selection_preprocessor +from models.experimental.SSR.tests.tile_refinement.test_tile_refinement import create_tile_refinement_preprocessor +from models.experimental.SSR.tests.tile_refinement.test_HAB import create_relative_position_index +from models.experimental.SSR.reference.SSR.model.net_blocks import window_reverse +from models.experimental.SSR.tests.tile_refinement.test_tile_refinement import get_precision_config + + +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_ssr_preprocessor(device, args, num_cls, depth, weight_dtype=ttnn.bfloat16): + """Custom preprocessor for SSR model""" + + def custom_preprocessor(torch_model, name, ttnn_module_args): + parameters = {} + + if isinstance(torch_model, SSR) or isinstance(torch_model, SSR_wo_conv): + # Preprocess tile selection model + select_params = preprocess_model_parameters( + initialize_model=lambda: torch_model.select_model, + custom_preprocessor=create_tile_selection_preprocessor(device, weight_dtype=weight_dtype), + device=device, + ) + parameters["select_model"] = select_params + + # Preprocess tile refinement model + + rpi_sa = create_relative_position_index((16, 16)) + + attn_mask = None + + # Create RPI for OCAB + overlap_win_size = int(16 * 0.5) + 16 + rpi_oca = torch.zeros((16 * 16, overlap_win_size * overlap_win_size), dtype=torch.long) + + # Create params dictionary + + tt_rpi_sa = ttnn.from_torch(rpi_sa, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32) + + tt_rpi_oca = ttnn.from_torch(rpi_oca, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.uint32) + + forward_params = {"rpi_sa": tt_rpi_sa, "attn_mask": attn_mask, "rpi_oca": tt_rpi_oca} + sr_params = preprocess_model_parameters( + initialize_model=lambda: torch_model.sr_model, + custom_preprocessor=create_tile_refinement_preprocessor( + device, forward_params, window_size=16, rpi_sa=rpi_sa, depth=depth + ), + device=device, + ) + parameters["sr_model"] = sr_params + + # Preprocess conv layers + conv_layers = ["conv_first", "conv_last"] + for 
conv_name in conv_layers: + if hasattr(torch_model, conv_name): + conv_layer = getattr(torch_model, conv_name) + parameters[conv_name] = { + "weight": ttnn.from_torch(conv_layer.weight, dtype=ttnn.bfloat16), + "bias": ttnn.from_torch(conv_layer.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16), + } + + # Preprocess conv_before_upsample (Sequential layer) + if hasattr(torch_model, "conv_before_upsample"): + conv_layer = torch_model.conv_before_upsample[0] # Conv2d layer + parameters["conv_before_upsample"] = { + "weight": ttnn.from_torch(conv_layer.weight, dtype=ttnn.bfloat16), + "bias": ttnn.from_torch(conv_layer.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16), + } + + # Preprocess upsample + if hasattr(torch_model, "upsample"): + upsample_params = preprocess_model_parameters( + initialize_model=lambda: torch_model.upsample, + custom_preprocessor=create_upsample_preprocessor(device), + device=device, + ) + parameters["upsample"] = upsample_params + + return parameters + + return custom_preprocessor + + +class MockArgs: + """Mock args class for testing""" + + def __init__(self): + self.token_size = 4 + self.imgsz = 256 + self.patchsz = 2 + self.pretrain = False + self.ckpt = None + self.dim = 96 + + +@pytest.mark.parametrize( + "input_shape, num_cls, with_conv, depth, num_heads", + [ + # ((1, 3, 256, 256), 1, True), + ((1, 3, 256, 256), 1, False, [1], [1]), + ((1, 3, 256, 256), 1, False, [6, 6, 6, 6, 6, 6], [6, 6, 6, 6, 6, 6]), + ], +) +@pytest.mark.parametrize( + "precision_config", + [ + lambda: get_precision_config("performance"), + lambda: get_precision_config("accuracy"), + ], + ids=["performance", "accuracy"], +) +def test_ssr_model(input_shape, num_cls, with_conv, depth, num_heads, precision_config): + """Test TTSSR model against PyTorch reference""" + # Get data types from precision configuration + + torch.manual_seed(0) + + input_dtype, weight_dtype = precision_config() + + # Create input tensor + x = torch.randn(input_shape) + _, _, H, W = x.shape + + # Create mock args + args = MockArgs() + + # Create reference PyTorch model + if with_conv: + ref_model = SSR(args, num_cls, depth, num_heads) + else: + ref_model = SSR_wo_conv(args, num_cls, depth, num_heads) + ref_model.eval() + + # Get reference output + with torch.no_grad(): + ref_sr, ref_patch_fea3, _, _ = ref_model(x) + # Open TTNN device with larger L1 cache to handle memory requirements + device = ttnn.open_device(device_id=0, l1_small_size=32768) # 128KB instead of 32KB + + memory_config = ttnn.L1_MEMORY_CONFIG + try: + # Preprocess model parameters + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_model, + custom_preprocessor=create_ssr_preprocessor(device, args, num_cls, depth, weight_dtype), + device=device, + ) + # Create TTNN model + if with_conv: + tt_model = TTSSR( + device=device, + parameters=parameters, + args=args, + num_cls=num_cls, + depth=depth, + num_heads=num_heads, + memory_config=memory_config, + ) + else: + tt_model = TTSSR_wo_conv( + device=device, + parameters=parameters, + args=args, + num_cls=num_cls, + depth=depth, + num_heads=num_heads, + dtype=input_dtype, + memory_config=memory_config, + ) + + # Convert input to TTNN tensor + tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + + # Run TTNN model + tt_sr, tt_patch_fea3 = tt_model(tt_input) + + # Convert back to torch tensors + tt_torch_sr = tt2torch_tensor(tt_sr) + tt_torch_patch_fea3 = tt2torch_tensor(tt_patch_fea3) + tt_torch_sr = tt_torch_sr.permute(0, 3, 1, 2) + + if not with_conv: + 
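+            # The no-conv variant returns the refined tiles as separate windows rather than one
+            # assembled image, so they are stitched back together with window_reverse below: each
+            # window spans H x W pixels and the reassembled canvas is (4*H) x (4*W), matching the
+            # reference model's full super-resolved output before the PCC comparison.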
_, _, H, W = x.shape + tt_torch_sr = window_reverse(tt_torch_sr.permute(0, 2, 3, 1), window_size=H, H=H * 4, W=W * 4) + tt_torch_sr = tt_torch_sr.permute(0, 3, 1, 2) + + # Compare outputs + sr_pass, sr_pcc_message = check_with_pcc(ref_sr, tt_torch_sr, 0.90) + fea3_pass, fea3_pcc_message = check_with_pcc(ref_patch_fea3, tt_torch_patch_fea3, 0.90) + logger.info(f"sr_pcc: {sr_pcc_message}") + logger.info(f"fea3_pcc: {fea3_pcc_message}") + + all_pass = sr_pass and fea3_pass + + if all_pass: + logger.info("TTSSR Test Passed!") + else: + logger.warning("TTSSR Test Failed!") + + assert sr_pass, f"SR output failed PCC check: {sr_pcc_message}" + assert fea3_pass, f"Patch fea3 failed PCC check: {fea3_pcc_message}" + + finally: + ttnn.close_device(device) diff --git a/models/experimental/SSR/tests/tile_refinement/test_CAB.py b/models/experimental/SSR/tests/tile_refinement/test_CAB.py new file mode 100644 index 000000000000..0fd3e24d6f82 --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_CAB.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.tile_refinement import CAB +from models.experimental.SSR.tt.tile_refinement import TTCAB +from tests.ttnn.utils_for_testing import check_with_pcc +from models.experimental.SSR.tests.tile_refinement.test_channel_attention import create_channel_attention_preprocessor + + +def create_cab_preprocessor(device, weight_dtype=ttnn.bfloat16, input_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {} + + # Extract the sequential layers from CAB + cab_layers = list(torch_model.cab.children()) + conv1 = cab_layers[0] # First Conv2d layer + conv2 = cab_layers[2] # Second Conv2d layer (after GELU) + channel_attention = cab_layers[3] # ChannelAttention module + + # Preprocess first convolution (3x3) + params["conv1"] = { + "weight": ttnn.from_torch(conv1.weight, dtype=weight_dtype, layout=ttnn.ROW_MAJOR_LAYOUT), + "bias": ttnn.from_torch(conv1.bias.reshape(1, 1, 1, -1), dtype=weight_dtype, layout=ttnn.ROW_MAJOR_LAYOUT), + } + + # Preprocess second convolution (3x3) + params["conv2"] = { + "weight": ttnn.from_torch(conv2.weight, dtype=weight_dtype, layout=ttnn.ROW_MAJOR_LAYOUT), + "bias": ttnn.from_torch(conv2.bias.reshape(1, 1, 1, -1), dtype=weight_dtype, layout=ttnn.ROW_MAJOR_LAYOUT), + } + + # Preprocess channel attention using existing preprocessor + channel_attention_preprocessor = create_channel_attention_preprocessor( + device, weight_dtype=weight_dtype, input_dtype=input_dtype + ) + params["channel_attention"] = channel_attention_preprocessor( + channel_attention, "channel_attention", ttnn_module_args + ) + + return params + + return custom_preprocessor + + +@pytest.mark.parametrize( + "batch_size, num_feat, height, width, compress_ratio, squeeze_factor", + [ + (1, 180, 64, 64, 3, 30), # SSR config + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat16]) +def test_cab_block( + device, batch_size, num_feat, height, width, compress_ratio, squeeze_factor, input_dtype, weight_dtype +): + torch.manual_seed(0) + + # Create reference model + ref_model = CAB(num_feat=num_feat, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor) + ref_model.eval() + + # Create input tensor 
+ input_tensor = torch.randn(batch_size, num_feat, height, width) + + # Reference forward pass + with torch.no_grad(): + ref_output = ref_model(input_tensor) + + parameters = ttnn.model_preprocessing.preprocess_model( + initialize_model=lambda: ref_model, + custom_preprocessor=create_cab_preprocessor(device, weight_dtype, input_dtype), + device=device, + run_model=lambda model: model(input_tensor), + ) + + memory_config = ttnn.L1_MEMORY_CONFIG + tt_model = TTCAB( + device=device, + parameters=parameters, + num_feat=num_feat, + compress_ratio=compress_ratio, + squeeze_factor=squeeze_factor, + memory_config=memory_config, + dtype=input_dtype, + ) + + tt_input = ttnn.from_torch( + input_tensor.permute(0, 2, 3, 1), + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=input_dtype, + memory_config=memory_config, + ) + + # TTNN forward pass + tt_output = tt_model(tt_input) + + tt_torch_output = ttnn.to_torch(tt_output) + tt_torch_output = tt_torch_output.permute(0, 3, 1, 2) # NHWC -> NCHW + + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.97) + logger.info(f"pcc: {pcc_message}") + + if does_pass: + logger.info("CAB Block Passed!") + else: + logger.warning("CAB Block Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_HAB.py b/models/experimental/SSR/tests/tile_refinement/test_HAB.py new file mode 100644 index 000000000000..ac32069a1eb5 --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_HAB.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.tile_refinement import HAB +from models.experimental.SSR.tt.tile_refinement import TTHAB +from ttnn.model_preprocessing import preprocess_linear_bias, preprocess_linear_weight +from tests.ttnn.utils_for_testing import check_with_pcc +from models.experimental.SSR.tests.tile_refinement.test_window_attn_tr import create_window_attention_preprocessor +from models.experimental.SSR.tests.tile_refinement.test_CAB import create_cab_preprocessor +from models.experimental.SSR.tests.common.test_mlp import create_mlp_preprocessor + + +def create_hab_preprocessor(device, window_size, rpi, weight_dtype=ttnn.bfloat16, input_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {} + + # Norm layers + params["norm1"] = { + "weight": preprocess_linear_weight(torch_model.norm1.weight, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT), + "bias": preprocess_linear_bias(torch_model.norm1.bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT), + } + params["norm2"] = { + "weight": preprocess_linear_weight(torch_model.norm2.weight, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT), + "bias": preprocess_linear_bias(torch_model.norm2.bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT), + } + relative_position_bias = torch_model.attn.relative_position_bias_table[rpi.view(-1)].view( + window_size * window_size, window_size * window_size, -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + params["relative_position_bias"] = ttnn.from_torch( + relative_position_bias.unsqueeze(0), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT + ) + # Window attention parameters + window_attention_preprocessor = create_window_attention_preprocessor( + device, (window_size, window_size), rpi, 
weight_dtype=weight_dtype + ) + params["attn"] = window_attention_preprocessor(torch_model.attn, "attn", ttnn_module_args) + + # Conv block parameters + cab_preprocessor = create_cab_preprocessor(device, weight_dtype=ttnn.bfloat16, input_dtype=input_dtype) + params["conv_block"] = cab_preprocessor(torch_model.conv_block, "conv_block", ttnn_module_args) + + # MLP parameters + mlp_preprocessor = create_mlp_preprocessor(device, weight_dtype=weight_dtype) + params["mlp"] = mlp_preprocessor(torch_model.mlp, "mlp", ttnn_module_args) + + # Conv scale + params["conv_scale"] = torch_model.conv_scale + + return params + + return custom_preprocessor + + +def create_relative_position_index(window_size): + """Create relative position index for window attention""" + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += window_size[0] - 1 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + return relative_coords.sum(-1) + + +@pytest.mark.parametrize( + "batch_size, height, width, dim, num_heads, window_size, shift_size, mlp_ratio", + [ + # SSR configurations + # (1, 64, 64, 180, 6, 16, 8, 2), # With shift + (1, 64, 64, 180, 6, 16, 0, 2), # Without shift + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_hab_block( + device, batch_size, height, width, dim, num_heads, window_size, shift_size, mlp_ratio, input_dtype, weight_dtype +): + torch.manual_seed(0) + + # Create reference model + ref_model = HAB( + dim=dim, + input_resolution=(height, width), + num_heads=num_heads, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + ) + ref_model.eval() + + # Create input tensors + input_tensor = torch.randn(batch_size, height * width, dim) + x_size = (height, width) + + # Create relative position index + rpi_sa = create_relative_position_index((window_size, window_size)) + + # Create attention mask for shifted windows + if shift_size > 0: + img_mask = torch.zeros((1, height, width, 1)) + h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)) + w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # Create attention mask + mask_windows = img_mask.view(1, height // window_size, window_size, width // window_size, window_size, 1) + mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float("-inf")).masked_fill(attn_mask == 0, float("0.0")) + else: + attn_mask = None + + # Reference forward pass + with torch.no_grad(): + ref_output = ref_model(input_tensor, x_size, rpi_sa, attn_mask) + + # Create TTNN model + parameters = ttnn.model_preprocessing.preprocess_model( + initialize_model=lambda: ref_model, + custom_preprocessor=create_hab_preprocessor( + device, 
window_size, rpi_sa, weight_dtype=weight_dtype, input_dtype=input_dtype + ), + device=device, + run_model=lambda model: model(input_tensor, x_size, rpi_sa, attn_mask), + ) + + memory_config = ttnn.L1_MEMORY_CONFIG + + tt_model = TTHAB( + device=device, + parameters=parameters, + dim=dim, + input_resolution=(height, width), + num_heads=num_heads, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + memory_config=memory_config, + dtype=input_dtype, + ) + + # Convert inputs to TTNN format + tt_input = ttnn.from_torch( + input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype, memory_config=memory_config + ) + + tt_rpi = ttnn.from_torch( + rpi_sa, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, memory_config=memory_config + ) + + tt_attn_mask = None + if attn_mask is not None: + tt_attn_mask = ttnn.from_torch( + attn_mask, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, memory_config=memory_config + ) + + # TTNN forward pass + tt_output = tt_model(tt_input, x_size, tt_rpi, tt_attn_mask) + + # Convert back to PyTorch format + tt_torch_output = ttnn.to_torch(tt_output) + + # Compare outputs + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.95) + + logger.info(f"pcc: {pcc_message}") + if does_pass: + logger.info("HAB Block Passed!") + else: + logger.warning("HAB Block Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_OCAB.py b/models/experimental/SSR/tests/tile_refinement/test_OCAB.py new file mode 100644 index 000000000000..d4c17ac3bb0a --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_OCAB.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn as nn +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.tile_refinement import OCAB +from models.experimental.SSR.tt.tile_refinement import TTOCAB + + +def create_ocab_preprocessor(device, tile_size=32, weight_dtype=ttnn.bfloat16, input_dtype=ttnn.bfloat16): + """Create custom preprocessor for OCAB parameters""" + + def custom_preprocessor(model, name): + parameters = {} + if isinstance(model, OCAB): + # Layer norm parameters + dim = model.norm1.weight.size(0) + padded_dim = ((dim + tile_size - 1) // tile_size) * tile_size + + norm1_weight_padded = torch.nn.functional.pad(model.norm1.weight, (0, padded_dim - dim)) + norm1_bias_padded = torch.nn.functional.pad(model.norm1.bias, (0, padded_dim - dim)) + + norm1_weight = norm1_weight_padded.view(1, 1, padded_dim // tile_size, tile_size) + norm1_bias = norm1_bias_padded.view(1, 1, padded_dim // tile_size, tile_size) + + parameters["norm1"] = { + "weight": ttnn.from_torch( + norm1_weight, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device + ), + "bias": ttnn.from_torch(norm1_bias, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device), + } + + # QKV linear layer with padded heads + qkv_weight = model.qkv.weight.T # [embed_dim, 3*embed_dim] + num_heads = model.num_heads + head_size = qkv_weight.shape[1] // (3 * num_heads) + padded_head_size = ((head_size + tile_size - 1) // tile_size) * tile_size + + if padded_head_size != head_size: + # Pad each head separately + qkv_chunks = 
torch.split(qkv_weight, head_size, dim=1) + qkv_weight_padded = torch.cat( + [torch.nn.functional.pad(chunk, (0, padded_head_size - head_size)) for chunk in qkv_chunks], dim=1 + ) + + if model.qkv.bias is not None: + qkv_bias_chunks = torch.split(model.qkv.bias, head_size, dim=0) + qkv_bias_padded = torch.cat( + [ + torch.nn.functional.pad(chunk, (0, padded_head_size - head_size)) + for chunk in qkv_bias_chunks + ], + dim=0, + ) + else: + qkv_bias_padded = None + else: + qkv_weight_padded = qkv_weight + qkv_bias_padded = model.qkv.bias + + parameters["qkv"] = { + "weight": ttnn.from_torch( + qkv_weight_padded, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device + ), + "bias": ttnn.from_torch(qkv_bias_padded, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device) + if qkv_bias_padded is not None + else None, + } + + # Relative position bias table + parameters["relative_position_bias_table"] = ttnn.from_torch( + model.relative_position_bias_table, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + + # Output projection + proj_weight = model.proj.weight.T + parameters["proj"] = { + "weight": ttnn.from_torch(proj_weight, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device), + "bias": ttnn.from_torch(model.proj.bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device), + } + + # Layer norm 2 + norm2_weight_padded = torch.nn.functional.pad(model.norm2.weight, (0, padded_dim - dim)) + norm2_bias_padded = torch.nn.functional.pad(model.norm2.bias, (0, padded_dim - dim)) + + norm2_weight = norm2_weight_padded.view(1, 1, padded_dim // tile_size, tile_size) + norm2_bias = norm2_bias_padded.view(1, 1, padded_dim // tile_size, tile_size) + + parameters["norm2"] = { + "weight": ttnn.from_torch( + norm2_weight, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device + ), + "bias": ttnn.from_torch(norm2_bias, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device), + } + + # MLP parameters + parameters["mlp"] = { + "fc1": { + "weight": ttnn.from_torch( + model.mlp.fc1.weight.T, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device + ), + "bias": ttnn.from_torch( + model.mlp.fc1.bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device + ), + }, + "fc2": { + "weight": ttnn.from_torch( + model.mlp.fc2.weight.T, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device + ), + "bias": ttnn.from_torch( + model.mlp.fc2.bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device + ), + }, + } + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize( + "dim, input_resolution, window_size, overlap_ratio, num_heads, input_shape", + ((180, (64, 64), 16, 0.5, 6, (1, 4096, 180)),), +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_ocab( + device, dim, input_resolution, window_size, overlap_ratio, num_heads, input_shape, input_dtype, weight_dtype +): + x = torch.randn(input_shape) + + # Create reference OCAB layer + ref_layer = OCAB( + dim=dim, + input_resolution=input_resolution, + window_size=window_size, + overlap_ratio=overlap_ratio, + num_heads=num_heads, + qkv_bias=True, + qk_scale=None, + mlp_ratio=2, + norm_layer=nn.LayerNorm, + ) + + h, w = input_resolution + x_size = (h, w) + overlap_win_size = int(window_size * overlap_ratio) + window_size + rpi = torch.zeros((window_size * window_size, overlap_win_size * overlap_win_size), dtype=torch.long) + 
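+    # NOTE: an all-zeros relative position index is used here as a unit-test placeholder, so every
+    # query-key pair reads the same entry of the relative position bias table; the full model derives
+    # the real overlapping-window index via HAT.calculate_rpi_oca(). Both the reference and TT layers
+    # receive this same zero index, so the PCC comparison still exercises the OCAB attention math.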
+ ref_output = ref_layer(x, x_size, rpi) + + ttnn.synchronize_device(device) + + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_layer, + custom_preprocessor=create_ocab_preprocessor(device, weight_dtype=weight_dtype, input_dtype=input_dtype), + device=device, + ) + + tt_layer = TTOCAB( + device=device, + dim=dim, + input_resolution=input_resolution, + window_size=window_size, + overlap_ratio=overlap_ratio, + num_heads=num_heads, + parameters=parameters, + dtype=input_dtype, + ) + + tt_input = ttnn.from_torch( + x, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b, memory_config=ttnn.L1_MEMORY_CONFIG + ) + tt_rpi = ttnn.from_torch(rpi, device=device, layout=ttnn.TILE_LAYOUT) + tt_output = tt_layer.forward(tt_input, x_size, tt_rpi) + tt_torch_output = tt2torch_tensor(tt_output) + + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.99) + + logger.info(f"pcc: {pcc_message}") + + if does_pass: + logger.info("OCAB Layer Passed!") + else: + logger.warning("OCAB Layer Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_RHAG.py b/models/experimental/SSR/tests/tile_refinement/test_RHAG.py new file mode 100644 index 000000000000..170beead56c8 --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_RHAG.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn as nn +import ttnn +from loguru import logger + +from models.experimental.SSR.tt.tile_refinement import TTRHAG +from models.experimental.SSR.reference.SSR.model.tile_refinement import RHAG +from models.experimental.SSR.tests.tile_refinement.test_HAB import create_relative_position_index +from models.experimental.SSR.tests.tile_refinement.test_atten_blocks import create_atten_blocks_preprocessor +from models.experimental.SSR.tests.tile_refinement.test_patch_embed_tile_refinement import ( + create_patch_embed_preprocessor_conv, +) + +from tests.ttnn.utils_for_testing import check_with_pcc +from ttnn.model_preprocessing import preprocess_model_parameters + + +def create_rhag_preprocessor(device, depth, window_size, rpi_sa, weight_dtype=ttnn.bfloat16, input_dtype=ttnn.bfloat16): + """Preprocessor for RHAG that handles all sub-components by importing existing preprocessors""" + + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {} + + # Import and use AttenBlocks preprocessor + atten_blocks_preprocessor = create_atten_blocks_preprocessor( + device, depth, window_size, rpi_sa, weight_dtype=weight_dtype, input_dtype=input_dtype + ) + params["residual_group"] = atten_blocks_preprocessor( + torch_model.residual_group, "residual_group", ttnn_module_args + ) + + # Preprocess conv layer parameters (if 1conv) + if hasattr(torch_model, "conv") and hasattr(torch_model.conv, "weight"): + conv_config = ttnn.Conv2dConfig(weights_dtype=ttnn.bfloat16) + params["conv"] = { + "weight": ttnn.prepare_conv_weights( + weight_tensor=ttnn.from_torch(torch_model.conv.weight, dtype=ttnn.bfloat16), + input_memory_config=ttnn.DRAM_MEMORY_CONFIG, + input_layout=ttnn.TILE_LAYOUT, + weights_format="OIHW", + in_channels=torch_model.conv.in_channels, + out_channels=torch_model.conv.out_channels, + batch_size=1, + input_height=torch_model.input_resolution[0], + input_width=torch_model.input_resolution[1], + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + has_bias=True, + groups=1, + 
device=device, + input_dtype=input_dtype, + conv_config=conv_config, + ), + "bias": ttnn.prepare_conv_bias( + bias_tensor=ttnn.from_torch(torch_model.conv.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16), + input_memory_config=ttnn.DRAM_MEMORY_CONFIG, + input_layout=ttnn.TILE_LAYOUT, + in_channels=torch_model.conv.in_channels, + out_channels=torch_model.conv.out_channels, + batch_size=1, + input_height=torch_model.input_resolution[0], + input_width=torch_model.input_resolution[1], + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + groups=1, + device=device, + input_dtype=input_dtype, + conv_config=conv_config, + ), + } + + if hasattr(torch_model, "patch_embed"): + patch_embed_params = preprocess_model_parameters( + initialize_model=lambda: torch_model.patch_embed, + custom_preprocessor=create_patch_embed_preprocessor_conv(device), + device=device, + ) + params["patch_embed"] = patch_embed_params + + return params + + return custom_preprocessor + + +@pytest.mark.parametrize( + "batch_size, height, width, dim, num_heads, window_size, depth, overlap_ratio, mlp_ratio, resi_connection", + [ + # SSR config + (1, 64, 64, 180, 3, 16, 3, 0.5, 2, "1conv"), + (1, 64, 64, 180, 6, 16, 6, 0.5, 2, "1conv"), + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_rhag( + device, + batch_size, + height, + width, + dim, + num_heads, + window_size, + depth, + overlap_ratio, + mlp_ratio, + resi_connection, + input_dtype, + weight_dtype, +): + torch.manual_seed(0) + + # Create reference model + ref_model = RHAG( + dim=dim, + input_resolution=(height, width), + depth=depth, + num_heads=num_heads, + window_size=window_size, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=overlap_ratio, + mlp_ratio=mlp_ratio, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + img_size=max(height, width), + patch_size=4, + resi_connection=resi_connection, + ) + ref_model.eval() + + # Create input tensors + input_tensor = torch.randn(batch_size, height * width, dim) + x_size = (height, width) + + # Create relative position indices + rpi_sa = create_relative_position_index((window_size, window_size)) + + # attention mask for shifted windows + attn_mask = None + + # Create RPI for OCAB + overlap_win_size = int(window_size * overlap_ratio) + window_size + rpi_oca = torch.zeros((window_size * window_size, overlap_win_size * overlap_win_size), dtype=torch.long) + + # Create params dictionary + params = {"rpi_sa": rpi_sa, "attn_mask": attn_mask, "rpi_oca": rpi_oca} + + # Reference forward pass + with torch.no_grad(): + ref_output = ref_model(input_tensor, x_size, params) + + # Create TTNN model + # parameters = preprocess_model_parameters( + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_model, + custom_preprocessor=create_rhag_preprocessor( + device, depth, window_size, rpi_sa, weight_dtype=weight_dtype, input_dtype=input_dtype + ), + device=device, + ) + + tt_model = TTRHAG( + device=device, + parameters=parameters, + dim=dim, + input_resolution=(height, width), + depth=depth, + num_heads=num_heads, + window_size=window_size, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=overlap_ratio, + mlp_ratio=mlp_ratio, + img_size=max(height, width), + patch_size=4, + 
resi_connection=resi_connection, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=input_dtype, + ) + + # Convert inputs to TTNN format + tt_input = ttnn.from_torch(input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + + tt_rpi_sa = ttnn.from_torch(rpi_sa, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32) + + tt_rpi_oca = ttnn.from_torch(rpi_oca, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.uint32) + + tt_params = {"rpi_sa": tt_rpi_sa, "attn_mask": None, "rpi_oca": tt_rpi_oca} + + # TTNN forward pass + tt_output = tt_model(tt_input, x_size, tt_params) + + # Convert back to PyTorch format + tt_torch_output = ttnn.to_torch(tt_output) + + # Compare outputs + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.85) + logger.info(f"pcc: {pcc_message}") + + if does_pass: + logger.info("RHAG Passed!") + else: + logger.warning("RHAG Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_atten_blocks.py b/models/experimental/SSR/tests/tile_refinement/test_atten_blocks.py new file mode 100644 index 000000000000..5faa93c420ff --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_atten_blocks.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn as nn +import ttnn +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.tile_refinement import AttenBlocks +from models.experimental.SSR.tt.tile_refinement import TTAttenBlocks +from models.experimental.SSR.tests.tile_refinement.test_HAB import ( + create_hab_preprocessor, + create_relative_position_index, +) +from models.experimental.SSR.tests.tile_refinement.test_OCAB import create_ocab_preprocessor +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_atten_blocks_preprocessor(device, depth, window_size, rpi_sa, weight_dtype, input_dtype): + """Preprocessor for AttenBlocks that handles multiple HAB blocks and one OCAB block""" + + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {} + + # Preprocess parameters for each HAB block + params["blocks"] = {} + hab_preprocessor = create_hab_preprocessor(device, window_size, rpi_sa, weight_dtype, input_dtype) + for i in range(depth): + params["blocks"][i] = hab_preprocessor(torch_model.blocks[i], f"blocks_{i}", ttnn_module_args) + + # Preprocess parameters for OCAB + ocab_preprocessor = create_ocab_preprocessor(device, weight_dtype=weight_dtype, input_dtype=input_dtype) + params["overlap_attn"] = ocab_preprocessor(torch_model.overlap_attn, "overlap_attn") + + return params + + return custom_preprocessor + + +@pytest.mark.parametrize( + "batch_size, height, width, dim, num_heads, window_size, depth, overlap_ratio, mlp_ratio", + [ + # SSR config + (1, 64, 64, 180, 3, 16, 3, 0.5, 2), + (1, 64, 64, 180, 6, 16, 6, 0.5, 2), + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_atten_blocks( + device, + batch_size, + height, + width, + dim, + num_heads, + window_size, + depth, + overlap_ratio, + mlp_ratio, + input_dtype, + weight_dtype, +): + torch.manual_seed(0) + + # Create reference model + ref_model = AttenBlocks( + dim=dim, + input_resolution=(height, width), + depth=depth, + num_heads=num_heads, + window_size=window_size, + 
compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=overlap_ratio, + mlp_ratio=mlp_ratio, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ) + ref_model.eval() + + # Create input tensors + input_tensor = torch.randn(batch_size, height * width, dim) + x_size = (height, width) + + # Create relative position indices + rpi_sa = create_relative_position_index((window_size, window_size)) + + # attention mask for shifted windows + attn_mask = None + + # Create RPI for OCAB + overlap_win_size = int(window_size * overlap_ratio) + window_size + rpi_oca = torch.zeros((window_size * window_size, overlap_win_size * overlap_win_size), dtype=torch.long) + + # Create params dictionary + params = {"rpi_sa": rpi_sa, "attn_mask": attn_mask, "rpi_oca": rpi_oca} + + # Reference forward pass + with torch.no_grad(): + ref_output = ref_model(input_tensor, x_size, params) + + # Create TTNN model + parameters = ttnn.model_preprocessing.preprocess_model( + initialize_model=lambda: ref_model, + custom_preprocessor=create_atten_blocks_preprocessor( + device, depth, window_size, rpi_sa, weight_dtype=weight_dtype, input_dtype=input_dtype + ), + device=device, + run_model=lambda model: model(input_tensor, x_size, params), + ) + + tt_model = TTAttenBlocks( + device=device, + parameters=parameters, + dim=dim, + input_resolution=(height, width), + depth=depth, + num_heads=num_heads, + window_size=window_size, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=overlap_ratio, + mlp_ratio=mlp_ratio, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=input_dtype, + ) + + # Convert inputs to TTNN format + tt_input = ttnn.from_torch(input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + + tt_rpi_sa = ttnn.from_torch(rpi_sa, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32) + + tt_rpi_oca = ttnn.from_torch(rpi_oca, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.uint32) + + tt_params = {"rpi_sa": tt_rpi_sa, "attn_mask": None, "rpi_oca": tt_rpi_oca} + + # TTNN forward pass + tt_output = tt_model(tt_input, x_size, tt_params) + + # Convert back to PyTorch format + tt_torch_output = ttnn.to_torch(tt_output) + + # Compare outputs + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.90) + logger.info(f"pcc: {pcc_message}") + + if does_pass: + logger.info("AttenBlocks Passed!") + else: + logger.warning("AttenBlocks Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_channel_attention.py b/models/experimental/SSR/tests/tile_refinement/test_channel_attention.py new file mode 100644 index 000000000000..5c96cf0f3596 --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_channel_attention.py @@ -0,0 +1,185 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn +from loguru import logger + +from models.experimental.SSR.tt.tile_refinement import TTChannelAttention +from ttnn.model_preprocessing import preprocess_model_parameters +from tests.ttnn.utils_for_testing import check_with_pcc + +from models.experimental.SSR.reference.SSR.model.tile_refinement import ChannelAttention + + +def create_channel_attention_preprocessor(device, weight_dtype=ttnn.bfloat16, input_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {} + + # Extract the sequential layers + layers = list(torch_model.attention.children()) + conv1 = layers[1] # First Conv2d layer + conv2 = layers[3] # Second Conv2d layer + + conv_config = ttnn.Conv2dConfig(weights_dtype=weight_dtype) + + # Preprocess first convolution + params["conv1"] = { + "weight": ttnn.prepare_conv_weights( + weight_tensor=ttnn.from_torch(conv1.weight, dtype=weight_dtype), + input_memory_config=ttnn.DRAM_MEMORY_CONFIG, + input_layout=ttnn.TILE_LAYOUT, + weights_format="OIHW", + in_channels=conv1.in_channels, + out_channels=conv1.out_channels, + batch_size=1, + input_height=1, + input_width=1, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + has_bias=True, + groups=1, + device=device, + input_dtype=input_dtype, + conv_config=conv_config, + ), + "bias": ttnn.prepare_conv_bias( + bias_tensor=ttnn.from_torch( + conv1.bias.reshape(1, 1, 1, -1), dtype=weight_dtype # Reshape to 4D: [1, 1, 1, out_channels] + ), + input_memory_config=ttnn.DRAM_MEMORY_CONFIG, + input_layout=ttnn.TILE_LAYOUT, + in_channels=conv1.in_channels, + out_channels=conv1.out_channels, + batch_size=1, + input_height=1, + input_width=1, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + device=device, + input_dtype=input_dtype, + conv_config=conv_config, + ), + } + + # Preprocess second convolution + params["conv2"] = { + "weight": ttnn.prepare_conv_weights( + weight_tensor=ttnn.from_torch(conv2.weight, dtype=weight_dtype), + input_memory_config=ttnn.DRAM_MEMORY_CONFIG, + input_layout=ttnn.TILE_LAYOUT, + weights_format="OIHW", + in_channels=conv2.in_channels, + out_channels=conv2.out_channels, + batch_size=1, + input_height=1, + input_width=1, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + has_bias=True, + groups=1, + device=device, + input_dtype=input_dtype, + conv_config=conv_config, + ), + "bias": ttnn.prepare_conv_bias( + bias_tensor=ttnn.from_torch( + conv2.bias.reshape(1, 1, 1, -1), dtype=weight_dtype # Reshape to 4D: [1, 1, 1, out_channels] + ), + input_memory_config=ttnn.DRAM_MEMORY_CONFIG, + input_layout=ttnn.TILE_LAYOUT, + in_channels=conv2.in_channels, + out_channels=conv2.out_channels, + batch_size=1, + input_height=1, + input_width=1, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + device=device, + input_dtype=input_dtype, + conv_config=conv_config, + ), + } + + return params + + return custom_preprocessor + + +@pytest.mark.parametrize( + "batch_size, num_feat, height, width, squeeze_factor", + [ + (1, 180, 64, 64, 30), # SSR config + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat16]) +def test_channel_attention(device, batch_size, num_feat, height, width, squeeze_factor, input_dtype, weight_dtype): + torch.manual_seed(0) 
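+ # Layout note (descriptive comment): TTChannelAttention consumes NHWC tensors, while the
+ # PyTorch ChannelAttention reference works in NCHW, so this test permutes the input to NHWC
+ # before ttnn.from_torch and permutes the TTNN output back to NCHW before the PCC check.
+ # The custom preprocessor above converts the module's two 1x1 conv layers into prepared TTNN
+ # conv weights/biases (prepare_conv_weights / prepare_conv_bias) on a 1x1 spatial grid,
+ # consistent with the squeeze-and-excitation style pooling inside channel attention.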
+ + # Create reference model + ref_model = ChannelAttention(num_feat=num_feat, squeeze_factor=squeeze_factor) + ref_model.eval() + + # Create input tensor + input_tensor = torch.randn(batch_size, num_feat, height, width) + + # Reference forward pass + with torch.no_grad(): + ref_output = ref_model(input_tensor) + + # Create TTNN model + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_model, + custom_preprocessor=create_channel_attention_preprocessor(device, weight_dtype, input_dtype), + device=device, + ) + + memory_config = ttnn.L1_MEMORY_CONFIG + tt_model = TTChannelAttention( + device=device, + parameters=parameters, + num_feat=num_feat, + squeeze_factor=squeeze_factor, + memory_config=memory_config, + dtype=input_dtype, + ) + + # Convert input to TTNN format (NHWC) + tt_input = ttnn.from_torch( + input_tensor.permute(0, 2, 3, 1), + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=input_dtype, + memory_config=memory_config, + ) + + # TTNN forward pass + tt_output = tt_model(tt_input) + + # Convert back to PyTorch format + tt_torch_output = ttnn.to_torch(tt_output) + tt_torch_output = tt_torch_output.permute(0, 3, 1, 2) # NHWC -> NCHW + + # Compare outputs + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.98) + logger.info(f"pcc: {pcc_message}") + + if does_pass: + logger.info("ChannelAttention Passed!") + else: + logger.warning("ChannelAttention Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_patch_embed_tile_refinement.py b/models/experimental/SSR/tests/tile_refinement/test_patch_embed_tile_refinement.py new file mode 100644 index 000000000000..26d5ffab75a5 --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_patch_embed_tile_refinement.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.tile_refinement import PatchEmbed +from models.experimental.SSR.tt.tile_refinement import TTPatchEmbed +from tests.ttnn.utils_for_testing import check_with_pcc +from ttnn.model_preprocessing import preprocess_model_parameters + + +def create_patch_embed_preprocessor_simple(device): + """Preprocessor for simple PatchEmbed (no conv projection)""" + + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {} + + if hasattr(torch_model, "norm") and torch_model.norm is not None: + params["norm"] = { + "weight": ttnn.from_torch(torch_model.norm.weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT), + "bias": ttnn.from_torch(torch_model.norm.bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT), + } + + return params + + return custom_preprocessor + + +def create_patch_embed_preprocessor_conv(device, weight_dtype=ttnn.bfloat16): + """Preprocessor for PatchEmbed with convolution projection""" + + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {} + + if hasattr(torch_model, "proj"): + conv_config = ttnn.Conv2dConfig(weights_dtype=ttnn.bfloat16) + + params["proj"] = { + "weight": ttnn.prepare_conv_weights( + weight_tensor=ttnn.from_torch(torch_model.proj.weight, dtype=ttnn.bfloat16), + input_memory_config=ttnn.DRAM_MEMORY_CONFIG, + input_layout=ttnn.TILE_LAYOUT, + weights_format="OIHW", + in_channels=torch_model.proj.in_channels, + out_channels=torch_model.proj.out_channels, + batch_size=1, + input_height=torch_model.img_size[0], + input_width=torch_model.img_size[1], + kernel_size=torch_model.patch_size, + stride=torch_model.patch_size, + padding=(0, 0), + dilation=(1, 1), + has_bias=True, + groups=1, + device=device, + input_dtype=ttnn.bfloat16, + conv_config=conv_config, + ), + "bias": ttnn.prepare_conv_bias( + bias_tensor=ttnn.from_torch(torch_model.proj.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16), + input_memory_config=ttnn.DRAM_MEMORY_CONFIG, + input_layout=ttnn.TILE_LAYOUT, + in_channels=torch_model.proj.in_channels, + out_channels=torch_model.proj.out_channels, + batch_size=1, + input_height=torch_model.img_size[0], + input_width=torch_model.img_size[1], + kernel_size=torch_model.patch_size, + stride=torch_model.patch_size, + padding=(0, 0), + dilation=(1, 1), + groups=1, + device=device, + input_dtype=ttnn.bfloat16, + conv_config=conv_config, + ), + } + + if hasattr(torch_model, "norm") and torch_model.norm is not None: + params["norm"] = { + "weight": ttnn.from_torch(torch_model.norm.weight, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT), + "bias": ttnn.from_torch(torch_model.norm.bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT), + } + + return params + + return custom_preprocessor + + +@pytest.mark.parametrize( + "batch_size, img_size, patch_size, in_chans, embed_dim, norm_layer", + [ + (1, 64, 2, 3, 180, None), # TR blk test + (1, 64, 4, 3, 180, None), # HAT blk test + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +def test_patch_embed_simple(device, batch_size, img_size, patch_size, in_chans, embed_dim, norm_layer): + """Test the simplified PatchEmbed implementation (flatten + transpose only)""" + torch.manual_seed(0) + + # Create reference model + ref_model = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer, + ) + ref_model.eval() + + # Create input tensor + 
input_tensor = torch.randn(batch_size, in_chans, img_size, img_size) + + # Reference forward pass + with torch.no_grad(): + ref_output = ref_model(input_tensor) + + # Create TTNN model + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_model, + custom_preprocessor=create_patch_embed_preprocessor_simple(device), + device=device, + ) + + memory_config = ttnn.L1_MEMORY_CONFIG + tt_model = TTPatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer, + device=device, + parameters=parameters, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + # Convert input to TTNN format + tt_input = ttnn.from_torch(input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16) + + # TTNN forward pass + batch_size, channels, height, width = tt_input.shape + tt_input = ttnn.reshape(tt_input, (batch_size, channels, height * width)) + tt_input = ttnn.transpose(tt_input, 1, 2) # [batch, height*width, channels] + tt_output = tt_model(tt_input) + + # Convert back to PyTorch format + tt_torch_output = ttnn.to_torch(tt_output) + + # Compare outputs + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.99) + logger.info(f"pcc: {pcc_message}") + + if does_pass: + logger.info("TR PatchEmbed Passed!") + else: + logger.warning("TR PatchEmbed Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_patch_unembed.py b/models/experimental/SSR/tests/tile_refinement/test_patch_unembed.py new file mode 100644 index 000000000000..e1f550aacdda --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_patch_unembed.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn +from loguru import logger +from tests.ttnn.utils_for_testing import check_with_pcc + +from models.utility_functions import torch_random + +from models.experimental.SSR.reference.SSR.model.tile_refinement import PatchUnEmbed +from models.experimental.SSR.tt.tile_refinement import TTPatchUnEmbed + + +@pytest.mark.parametrize( + "batch_size, img_size, patch_size, in_chans, embed_dim", + [ + (1, 64, 4, 3, 180), # TR blk test + (1, 64, 2, 3, 180), # HAT blk test + ], +) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +def test_tt_patch_unembed(device, batch_size, img_size, patch_size, in_chans, embed_dim, input_dtype): + """Test TTPatchUnEmbed against PyTorch reference implementation""" + torch.manual_seed(0) + + # Calculate patch dimensions + patches_resolution = [img_size // patch_size, img_size // patch_size] + num_patches = patches_resolution[0] * patches_resolution[1] + + # Create reference PyTorch model + torch_model = PatchUnEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + # Create TTNN model + tt_model = TTPatchUnEmbed( + mesh_device=device, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim + ) + + # Generate random input tensor (batch_size, num_patches, embed_dim) + torch_input = torch_random((batch_size, num_patches, embed_dim), -1, 1, dtype=torch.float32) + + # Run PyTorch reference + torch_output = torch_model(torch_input, patches_resolution) + + # Convert input to TTNN format + ttnn_input = ttnn.from_torch( + torch_input, dtype=input_dtype, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + # Run TTNN model + ttnn_output = 
tt_model(ttnn_input, patches_resolution) + ttnn_output_torch = ttnn.to_torch(ttnn_output) + + # Compare outputs + does_pass, pcc_message = check_with_pcc(torch_output, ttnn_output_torch, 0.99) + logger.info(f"pcc: {pcc_message}") + + if does_pass: + logger.info("TR PatchEmbed Passed!") + else: + logger.warning("TR PatchEmbed Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_tile_refinement.py b/models/experimental/SSR/tests/tile_refinement/test_tile_refinement.py new file mode 100644 index 000000000000..b7a2dbff4812 --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_tile_refinement.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.tile_refinement import TileRefinement +from models.experimental.SSR.tests.tile_refinement.test_patch_embed_tile_refinement import ( + create_patch_embed_preprocessor_conv, +) +from models.experimental.SSR.tests.tile_refinement.test_RHAG import create_rhag_preprocessor +from models.experimental.SSR.tests.tile_refinement.test_upsample import create_upsample_preprocessor +from models.experimental.SSR.tt.tile_refinement import TTTileRefinement +from models.experimental.SSR.tests.tile_refinement.test_HAB import create_relative_position_index + +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_tile_refinement_preprocessor( + device, forward_params, window_size, rpi_sa, depth=[6], weight_dtype=ttnn.bfloat16, input_dtype=ttnn.bfloat16 +): + """Custom preprocessor for TileRefinement model""" + + def custom_preprocessor(torch_model, name, ttnn_module_args): + parameters = {} + parameters["forward_params"] = forward_params + if isinstance(torch_model, TileRefinement): + # Preprocess conv layers + conv_layers = ["conv_first", "conv_after_body", "conv_last"] + + for conv_name in conv_layers: + if hasattr(torch_model, conv_name): + conv_layer = getattr(torch_model, conv_name) + if hasattr(conv_layer, "weight"): # Direct conv layer + conv_config = ttnn.Conv2dConfig(weights_dtype=ttnn.bfloat16) + parameters[conv_name] = { + "weight": ttnn.from_torch(conv_layer.weight, dtype=ttnn.bfloat16), + "bias": ttnn.from_torch(conv_layer.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16), + } + + if hasattr(torch_model, "conv_before_upsample") and torch_model.conv_before_upsample is not None: + conv_layer = torch_model.conv_before_upsample[0] # Conv2d layer + conv_config = ttnn.Conv2dConfig(weights_dtype=ttnn.bfloat16) + parameters["conv_before_upsample"] = { + "weight": ttnn.from_torch(conv_layer.weight, dtype=ttnn.bfloat16), + "bias": ttnn.from_torch(conv_layer.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16), + } + + # Preprocess layer norm + if hasattr(torch_model, "norm"): + dim = torch_model.norm.weight.size(0) + padded_dim = ((dim + 31) // 32) * 32 # Round up to nearest multiple of 32 = 192 + + norm_weight_padded = torch.nn.functional.pad(torch_model.norm.weight, (0, padded_dim - dim)) + norm_bias_padded = torch.nn.functional.pad(torch_model.norm.bias, (0, padded_dim - dim)) + + norm_weight = norm_weight_padded.view(1, 1, padded_dim // 32, 32) + norm_bias = norm_bias_padded.view(1, 1, padded_dim // 32, 32) + + parameters["norm"] = { + "weight": ttnn.from_torch( + 
norm_weight, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device + ), + "bias": ttnn.from_torch( + norm_bias, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device + ), + } + + # Preprocess relative position indices + if hasattr(torch_model, "relative_position_index_SA"): + parameters["relative_position_index_SA"] = ttnn.from_torch( + torch_model.relative_position_index_SA, dtype=ttnn.int32 + ) + + if hasattr(torch_model, "relative_position_index_OCA"): + parameters["relative_position_index_OCA"] = ttnn.from_torch( + torch_model.relative_position_index_OCA, dtype=ttnn.int32 + ) + if hasattr(torch_model, "patch_embed"): + patch_embed_params = preprocess_model_parameters( + initialize_model=lambda: torch_model.patch_embed, + custom_preprocessor=create_patch_embed_preprocessor_conv(device, weight_dtype=weight_dtype), + device=device, + ) + parameters["patch_embed"] = patch_embed_params + + if hasattr(torch_model, "upsample"): + upsample_params = preprocess_model_parameters( + initialize_model=lambda: torch_model.upsample, + custom_preprocessor=create_upsample_preprocessor(device), + device=device, + ) + parameters["upsample"] = upsample_params + + if hasattr(torch_model, "layers"): + for i in range(len(torch_model.layers)): + rhag_params = preprocess_model_parameters( + initialize_model=lambda: torch_model.layers[i], + custom_preprocessor=create_rhag_preprocessor( + device, + depth=depth[i], + window_size=window_size, + rpi_sa=rpi_sa, + weight_dtype=weight_dtype, + input_dtype=input_dtype, + ), + device=device, + ) + parameters[f"layers.{i}"] = rhag_params + + return parameters + + return custom_preprocessor + + +def get_precision_config(precision_type): + """Get precision configuration for the given type""" + if precision_type == "performance": + return ttnn.bfloat8_b, ttnn.bfloat8_b + elif precision_type == "accuracy": + return ttnn.bfloat16, ttnn.bfloat16 + else: + raise ValueError(f"Unknown precision type: {precision_type}") + + +@pytest.mark.parametrize( + "img_size, patch_size, embed_dim, depths, num_heads, window_size, mlp_ratio, upscale, input_shape", + [ + (64, 2, 180, [1], [1], 16, 2, 4, (3, 3, 64, 64)), + (64, 2, 180, [6, 6, 6, 6, 6, 6], [6, 6, 6, 6, 6, 6], 16, 2, 4, (3, 3, 64, 64)), + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize( + "precision_config", + [ + lambda: get_precision_config("performance"), + lambda: get_precision_config("accuracy"), + ], + ids=["performance", "accuracy"], +) +def test_tile_refinement( + device, + img_size, + patch_size, + embed_dim, + depths, + num_heads, + window_size, + mlp_ratio, + upscale, + input_shape, + precision_config, +): + """Test TTTileRefinement model against PyTorch reference""" + + # Get data types from precision configuration + input_dtype, weight_dtype = precision_config() + + # Create input tensor + x = torch.randn(input_shape) + overlap_ratio = 0.5 + # Create reference PyTorch model + ref_model = TileRefinement( + img_size=img_size, + patch_size=patch_size, + in_chans=3, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=overlap_ratio, + mlp_ratio=mlp_ratio, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + use_checkpoint=False, + upscale=upscale, + img_range=1.0, + upsampler="pixelshuffle", + resi_connection="1conv", + ) + + rpi_sa = 
create_relative_position_index((window_size, window_size)) + + # attention mask for shifted windows + attn_mask = None + + # Create RPI for OCAB + overlap_win_size = int(window_size * overlap_ratio) + window_size + rpi_oca = torch.zeros((window_size * window_size, overlap_win_size * overlap_win_size), dtype=torch.long) + + tt_rpi_sa = ttnn.from_torch(rpi_sa, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32) + + tt_rpi_oca = ttnn.from_torch(rpi_oca, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.uint32) + + tt_params = {"rpi_sa": tt_rpi_sa, "attn_mask": attn_mask, "rpi_oca": tt_rpi_oca} + + ref_model.eval() + + # Get reference output (both image and features) + with torch.no_grad(): + ref_output, ref_features = ref_model(x) + + # Preprocess model parameters + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_model, + custom_preprocessor=create_tile_refinement_preprocessor( + device, tt_params, window_size, rpi_sa, depth=depths, weight_dtype=weight_dtype, input_dtype=input_dtype + ), + device=device, + ) + + # Create TTNN model + tt_model = TTTileRefinement( + device=device, + parameters=parameters, + img_size=img_size, + patch_size=patch_size, + in_chans=3, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=0.5, + mlp_ratio=mlp_ratio, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + upscale=upscale, + img_range=1.0, + upsampler="pixelshuffle", + resi_connection="1conv", + dtype=input_dtype, + ) + + # Convert input to TTNN tensor + tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + + # Run TTNN model + tt_output, tt_features = tt_model(tt_input) + + # Convert back to torch tensors + tt_torch_output = tt2torch_tensor(tt_output) + tt_torch_features = tt2torch_tensor(tt_features) + + tt_torch_output = tt_torch_output.permute(0, 3, 1, 2) + tt_torch_features = tt_torch_features.permute(0, 3, 1, 2) + + # Compare outputs + output_pass, output_pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.90) + features_pass, features_pcc_message = check_with_pcc(ref_features, tt_torch_features, 0.90) + logger.info(f"output_pcc: {output_pcc_message}") + logger.info(f"features_pcc: {features_pcc_message}") + + if output_pass and features_pass: + logger.info("TTTileRefinement Test Passed!") + else: + logger.warning("TTTileRefinement Test Failed!") + + assert output_pass, f"Output comparison failed: {output_pcc_message}" + assert features_pass, f"Features comparison failed: {features_pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_upsample.py b/models/experimental/SSR/tests/tile_refinement/test_upsample.py new file mode 100644 index 000000000000..ceb0070706d4 --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_upsample.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn as nn +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.tile_refinement import Upsample +from models.experimental.SSR.tt.tile_refinement import TTUpsample + +from models.utility_functions import ( + tt2torch_tensor, +) +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_upsample_preprocessor(device): + def custom_preprocessor(model, name): + """Custom preprocessor for converting PyTorch weights to TTNN format""" + parameters = {} + if isinstance(model, Upsample): + conv_idx = 0 + for i, layer in enumerate(model): + if isinstance(layer, nn.Conv2d): + parameters[f"conv_{conv_idx}"] = {} + parameters[f"conv_{conv_idx}"]["weight"] = ttnn.from_torch( + layer.weight, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT + ) + if layer.bias is not None: + parameters[f"conv_{conv_idx}"]["bias"] = ttnn.from_torch( + torch.reshape(layer.bias, (1, 1, 1, -1)), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT + ) + conv_idx += 1 + + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize( + "scale,num_feat,batch_size,input_size", + [(4, 64, 1, 256)], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +def test_upsample(device, scale, num_feat, batch_size, input_size, input_dtype): + """Test Upsample block against PyTorch reference""" + torch.manual_seed(0) + if batch_size == 8: + pytest.xfail( + "Statically allocated circular buffers in program 136 clash with L1 buffers on core range [(x=0,y=0) - (x=7,y=7)]. L1 buffer allocated at 118272 and static circular buffer region ends at 435840" + ) + + # Create PyTorch reference model + torch_model = Upsample(scale, num_feat).eval() + + # Create test input + torch_input = torch.randn(batch_size, num_feat, input_size, input_size) + torch_output = torch_model(torch_input) + + # Preprocess model parameters + parameters = preprocess_model_parameters( + initialize_model=lambda: torch_model, + custom_preprocessor=create_upsample_preprocessor(device), + device=device, + ) + + # Convert input to TTNN format (NHWC) + ttnn_input = torch.permute(torch_input, (0, 2, 3, 1)) + ttnn_input = ttnn.from_torch(ttnn_input, dtype=input_dtype, layout=ttnn.TILE_LAYOUT, device=device) + + # Create TTNN model and run inference + ttnn_model = TTUpsample(scale, num_feat, device) + + ttnn_output = ttnn_model(ttnn_input, parameters=parameters) + tt_torch_output = tt2torch_tensor(ttnn_output) + tt_torch_output = tt_torch_output.permute(0, 3, 1, 2) + + does_pass, pcc_message = check_with_pcc(torch_output, tt_torch_output, 0.99) + + logger.info(f"pcc: {pcc_message}") + + if does_pass: + logger.info("Upsample Passed!") + else: + logger.warning("Upsample Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_refinement/test_window_attn_tr.py b/models/experimental/SSR/tests/tile_refinement/test_window_attn_tr.py new file mode 100644 index 000000000000..54ca8ac56e23 --- /dev/null +++ b/models/experimental/SSR/tests/tile_refinement/test_window_attn_tr.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.tile_refinement import WindowAttention +from models.experimental.SSR.tt.tile_refinement import TTWindowAttentionTR +from ttnn.model_preprocessing import preprocess_linear_bias, preprocess_linear_weight +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_window_attention_preprocessor(device, window_size=None, rpi=None, tile_size=32, weight_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {} + + # QKV linear layer + num_heads = torch_model.num_heads + head_size = torch_model.qkv.weight.shape[0] // (3 * num_heads) + # nearest multiple of tile_size + padded_head_size = ((head_size + tile_size - 1) // tile_size) * tile_size + qkv_weight = torch_model.qkv.weight + qkv_bias = torch_model.qkv.bias + + if padded_head_size != head_size: + # Weight: [3*num_heads*head_size, in_features] + qkv_weight = qkv_weight.view(3 * num_heads, head_size, -1) + qkv_weight = torch.nn.functional.pad(qkv_weight, (0, 0, 0, padded_head_size - head_size), "constant", 0) + qkv_weight = qkv_weight.reshape(3 * num_heads * padded_head_size, -1) + + if qkv_bias is not None: + # Bias: [3*num_heads, head_size] + qkv_bias = qkv_bias.view(3 * num_heads, head_size) + qkv_bias = torch.nn.functional.pad(qkv_bias, (0, padded_head_size - head_size), "constant", 0) + qkv_bias = qkv_bias.reshape(3 * num_heads * padded_head_size) + + params["qkv"] = { + "weight": preprocess_linear_weight(qkv_weight, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT), + "bias": preprocess_linear_bias(qkv_bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT) + if qkv_bias is not None + else None, + } + + # Projection layer + params["proj"] = { + "weight": preprocess_linear_weight(torch_model.proj.weight, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT), + "bias": preprocess_linear_bias(torch_model.proj.bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT) + if torch_model.proj.bias is not None + else None, + } + + # Relative position bias table + relative_position_bias = torch_model.relative_position_bias_table[rpi.view(-1)].view( + window_size[0] * window_size[1], window_size[0] * window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + params["relative_position_bias"] = ttnn.from_torch( + relative_position_bias.unsqueeze(0), dtype=weight_dtype, layout=ttnn.TILE_LAYOUT + ) + return params + + return custom_preprocessor + + +@pytest.mark.parametrize( + "batch_size, num_windows, window_size, dim, num_heads", + [ + (1, 16, (16, 16), 180, 6), # SSR config + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_window_attention(device, batch_size, num_windows, window_size, dim, num_heads, input_dtype, weight_dtype): + torch.manual_seed(0) + + # Create reference model + ref_model = WindowAttention( + dim=dim, + window_size=window_size, + num_heads=num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ) + ref_model.eval() + + # Create input tensors + window_area = window_size[0] * window_size[1] + input_tensor = torch.randn(batch_size * num_windows, window_area, dim) + + # Create relative position index + coords_h = torch.arange(window_size[0]) + coords_w = 
torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += window_size[0] - 1 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + rpi = relative_coords.sum(-1) + + # Reference forward pass + with torch.no_grad(): + ref_output = ref_model(input_tensor, rpi=rpi, mask=None) + + # Create TTNN model + parameters = ttnn.model_preprocessing.preprocess_model( + initialize_model=lambda: ref_model, + custom_preprocessor=create_window_attention_preprocessor(device, window_size, rpi, weight_dtype=weight_dtype), + device=device, + run_model=lambda model: model(input_tensor, rpi=rpi, mask=None), + ) + + memory_config = ttnn.L1_MEMORY_CONFIG + + tt_model = TTWindowAttentionTR( + device=device, + parameters=parameters, + dim=dim, + window_size=window_size, + num_heads=num_heads, + memory_config=memory_config, + dtype=input_dtype, + ) + + # Convert inputs to TTNN format + tt_input = ttnn.from_torch( + input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype, memory_config=memory_config + ) + + tt_rpi = ttnn.from_torch(rpi, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32) + + # TTNN forward pass + tt_output = tt_model(tt_input, rpi=tt_rpi, mask=None) + + # Convert back to PyTorch format + tt_torch_output = ttnn.to_torch(tt_output) + + # Compare outputs + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.97) + + logger.info(f"pcc: {pcc_message}") + + if does_pass: + logger.info("Window Attention Passed!") + else: + logger.warning("Window Attention Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_selection/test_basic_block.py b/models/experimental/SSR/tests/tile_selection/test_basic_block.py new file mode 100644 index 000000000000..4d859a516b70 --- /dev/null +++ b/models/experimental/SSR/tests/tile_selection/test_basic_block.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from loguru import logger + +import ttnn +from models.experimental.SSR.tt.tile_selection import TTBasicLayer, TTPatchMerging +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc + +from models.experimental.SSR.tests.tile_selection.test_swin_transformer_block import ( + create_swin_transformer_block_preprocessor, +) +from models.experimental.SSR.tests.tile_selection.test_patch_merging import create_patch_merging_preprocessor +from models.experimental.SSR.reference.SSR.model.net_blocks import PatchMerging, BasicLayer + + +def create_basic_layer_preprocessor(device, dim, weight_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {"blocks": {}} + + # Process each transformer block + for i, block in enumerate(torch_model.blocks): + params["blocks"][i] = preprocess_model_parameters( + initialize_model=lambda: block, + custom_preprocessor=create_swin_transformer_block_preprocessor(device, weight_dtype), + device=device, + ) + + # Process downsampling layer if present + if torch_model.downsample is not None: + params["downsample"] = preprocess_model_parameters( + initialize_model=lambda: torch_model.downsample, + custom_preprocessor=create_patch_merging_preprocessor(device, dim, weight_dtype=weight_dtype), + device=device, + ) + + return params + + return custom_preprocessor + + +@pytest.mark.parametrize( + "batch_size, input_resolution, dim, depth, num_heads, window_size, has_downsample", + [ + (3, (128, 128), 96, 2, 3, 7, True), + (3, (64, 64), 192, 2, 3, 7, True), + (3, (32, 32), 384, 2, 3, 7, True), + (3, (16, 16), 768, 2, 3, 7, True), + (3, (8, 8), 1536, 2, 3, 7, True), + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}]) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_basic_layer( + device, batch_size, input_resolution, dim, depth, num_heads, window_size, has_downsample, input_dtype, weight_dtype +): + torch.manual_seed(0) + + H, W = input_resolution + + # Create reference model + downsample = PatchMerging if has_downsample else None + ref_layer = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=4.0, + downsample=downsample, + ) + ref_layer.eval() + + # Create input tensor [B, H*W, C] + input_tensor = torch.randn(batch_size, H * W, dim) + + # Reference forward pass + ref_output = ref_layer(input_tensor) + + # Create ttnn model + params = preprocess_model_parameters( + initialize_model=lambda: ref_layer, + custom_preprocessor=create_basic_layer_preprocessor(device, dim, weight_dtype), + device=device, + ) + + tt_layer = TTBasicLayer( + device=device, + parameters=params, + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=4.0, + downsample=TTPatchMerging if has_downsample else None, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=input_dtype, + ) + + # Convert input to ttnn + tt_input = ttnn.from_torch(input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + + # ttnn forward pass + tt_output = tt_layer(tt_input) + tt_torch_output = tt2torch_tensor(tt_output) + + # Compare outputs + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.98) + logger.info(f"PCC: 
{pcc_message}") + + if does_pass: + logger.info("BasicLayer Passed!") + else: + logger.warning("BasicLayer Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_selection/test_mask_token_inference.py b/models/experimental/SSR/tests/tile_selection/test_mask_token_inference.py new file mode 100644 index 000000000000..cba7f4692610 --- /dev/null +++ b/models/experimental/SSR/tests/tile_selection/test_mask_token_inference.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from loguru import logger +from ttnn.model_preprocessing import preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight +from models.experimental.SSR.reference.SSR.model.tile_selection import mask_token_inference +from models.experimental.SSR.tt.tile_selection import TTMaskTokenInference + +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_mask_token_inference_preprocessor(device, weight_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + parameters = {} + if ( + hasattr(torch_model, "norm") + and hasattr(torch_model, "q") + and hasattr(torch_model, "k") + and hasattr(torch_model, "v") + and hasattr(torch_model, "proj") + ): + # Layer norm parameters + parameters["norm"] = {} + parameters["norm"]["weight"] = ttnn.from_torch( + torch_model.norm.weight, dtype=weight_dtype, device=device, layout=ttnn.TILE_LAYOUT + ) + parameters["norm"]["bias"] = ttnn.from_torch( + torch_model.norm.bias, dtype=weight_dtype, device=device, layout=ttnn.TILE_LAYOUT + ) + + # QKV linear layers + parameters["proj"] = {} + + parameters["proj"]["weight"] = preprocess_linear_weight(torch_model.proj.weight, dtype=weight_dtype) + + qkv_weight = torch.cat([torch_model.q.weight, torch_model.k.weight, torch_model.v.weight], dim=0) + + parameters["qkv"] = {} + parameters["qkv"]["weight"] = preprocess_linear_weight(qkv_weight, dtype=weight_dtype) + + if torch_model.q.bias is not None: + qkv_bias = torch.cat([torch_model.q.bias, torch_model.k.bias, torch_model.v.bias], dim=0) + parameters["qkv"]["bias"] = preprocess_linear_bias(qkv_bias, dtype=weight_dtype) + + parameters["proj"]["bias"] = preprocess_linear_bias(torch_model.proj.bias, dtype=weight_dtype) + + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize( + "input_shape, dim, num_heads", + (((3, 17, 3072), 3072, 1),), +) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_mask_token_inference(device, input_shape, dim, num_heads, input_dtype, weight_dtype): + # Create test input [B, N, C] + input_tensor = torch.randn(input_shape) + + ref_layer = mask_token_inference(dim=dim, num_heads=num_heads, qkv_bias=False) + ref_layer.eval() + ref_output = ref_layer(input_tensor) + + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_layer, + custom_preprocessor=create_mask_token_inference_preprocessor(device, weight_dtype), + device=device, + ) + + tt_layer = TTMaskTokenInference(device=device, parameters=parameters, dim=dim, num_heads=num_heads) + + tt_input = ttnn.from_torch(input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + tt_output = tt_layer(tt_input) + tt_torch_output = tt2torch_tensor(tt_output) + + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 
0.99) + + logger.info(f"PCC: {pcc_message}") + + if does_pass: + logger.info("MaskTokenInference Passed!") + else: + logger.warning("MaskTokenInference Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_selection/test_patch_embed.py b/models/experimental/SSR/tests/tile_selection/test_patch_embed.py new file mode 100644 index 000000000000..370d4f108c5e --- /dev/null +++ b/models/experimental/SSR/tests/tile_selection/test_patch_embed.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import ttnn +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.net_blocks import PatchEmbed +from models.experimental.SSR.tt.tile_selection import TTPatchEmbed +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_patch_embed_preprocessor(device, weight_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + parameters = {} + if isinstance(torch_model, PatchEmbed): + # Extract Conv2d weights + conv_weight = torch_model.proj.weight # Shape: [out_channels, in_channels, kernel_height, kernel_width] + conv_bias = torch_model.proj.bias # Shape: [out_channels] + + parameters["proj"] = {} + # Keep weights in 4D format + parameters["proj"]["weight"] = ttnn.from_torch( + conv_weight, dtype=weight_dtype, layout=ttnn.ROW_MAJOR_LAYOUT + ) + # Reshape bias to [1, 1, 1, out_channels] format expected by conv2d + conv_bias_reshaped = conv_bias.reshape(1, 1, 1, -1) + parameters["proj"]["bias"] = ttnn.from_torch( + conv_bias_reshaped, dtype=weight_dtype, layout=ttnn.ROW_MAJOR_LAYOUT + ) + + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize("img_size, ch, patch_size, embed_dim, norm_layer", ((256, 3, 2, 96, None),)) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}]) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat16]) +def test_patch_embed(device, img_size, ch, patch_size, embed_dim, norm_layer, input_dtype, weight_dtype): + input_shape = (3, ch, img_size, img_size) + + x = torch.randn(input_shape) + + ref_layer = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=ch, + embed_dim=embed_dim, + norm_layer=norm_layer, + ) + + ref_output = ref_layer(x) + + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_layer, + custom_preprocessor=create_patch_embed_preprocessor(device, weight_dtype), + device=device, + ) + + tt_layer = TTPatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=ch, + embed_dim=embed_dim, + device=device, + parameters=parameters, + dtype=input_dtype, + ) + + # NCHW -> NHWC + x = x.permute(0, 2, 3, 1) + + tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + tt_output = tt_layer(tt_input) + tt_torch_output = tt2torch_tensor(tt_output) + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.99) + + logger.info(f"PCC: {pcc_message}") + + if does_pass: + logger.info("PatchEmbed Passed!") + else: + logger.warning("PatchEmbed Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_selection/test_patch_merging.py b/models/experimental/SSR/tests/tile_selection/test_patch_merging.py new file mode 100644 
index 000000000000..2b119e2f3a0d --- /dev/null +++ b/models/experimental/SSR/tests/tile_selection/test_patch_merging.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from loguru import logger + +import ttnn +from models.experimental.SSR.reference.SSR.model.net_blocks import PatchMerging +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc +from models.experimental.SSR.tt.tile_selection import TTPatchMerging + + +def create_patch_merging_preprocessor(device, dim, weight_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + params = {} + + # Create conv kernels for patch merging (same as in forward pass) + kernel_top_left = torch.zeros(dim, 1, 2, 2, dtype=torch.bfloat16) + kernel_top_left[:, 0, 0, 0] = 1.0 + + kernel_bottom_left = torch.zeros(dim, 1, 2, 2, dtype=torch.bfloat16) + kernel_bottom_left[:, 0, 1, 0] = 1.0 + + kernel_top_right = torch.zeros(dim, 1, 2, 2, dtype=torch.bfloat16) + kernel_top_right[:, 0, 0, 1] = 1.0 + + kernel_bottom_right = torch.zeros(dim, 1, 2, 2, dtype=torch.bfloat16) + kernel_bottom_right[:, 0, 1, 1] = 1.0 + + # Convert to TTNN tensors + params["conv_kernels"] = { + "top_left": ttnn.from_torch(kernel_top_left, device=device, dtype=ttnn.bfloat16), + "bottom_left": ttnn.from_torch(kernel_bottom_left, device=device, dtype=ttnn.bfloat16), + "top_right": ttnn.from_torch(kernel_top_right, device=device, dtype=ttnn.bfloat16), + "bottom_right": ttnn.from_torch(kernel_bottom_right, device=device, dtype=ttnn.bfloat16), + } + + # Linear reduction layer + params["reduction"] = { + "weight": ttnn.from_torch( + torch_model.reduction.weight.transpose(0, 1), # Transpose for ttnn.linear + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + device=device, + ) + } + + # Layer normalization + params["norm"] = { + "weight": ttnn.from_torch( + torch_model.norm.weight, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + device=device, + ), + "bias": ttnn.from_torch( + torch_model.norm.bias, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT, + device=device, + ), + } + + return params + + return custom_preprocessor + + +@pytest.mark.parametrize( + "batch_size, input_resolution, dim", + ( + (3, (128, 128), 96), + (3, (64, 64), 192), + (3, (32, 32), 384), + (3, (16, 16), 768), + (3, (8, 8), 1536), + ), +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}]) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_patch_merging(device, batch_size, input_resolution, dim, input_dtype, weight_dtype): + torch.manual_seed(0) + + H, W = input_resolution + + # Create reference model + ref_layer = PatchMerging(input_resolution=input_resolution, dim=dim) + ref_layer.eval() + + # Create input tensor [B, H*W, C] + input_tensor = torch.randn(batch_size, H * W, dim) + + # Reference forward pass + ref_output = ref_layer(input_tensor) + + # Create ttnn model + params = preprocess_model_parameters( + initialize_model=lambda: ref_layer, + custom_preprocessor=create_patch_merging_preprocessor(device, dim, weight_dtype), + device=device, + ) + + tt_layer = TTPatchMerging( + device=device, + parameters=params, + input_resolution=input_resolution, + dim=dim, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=input_dtype, + ) + + # Convert input to ttnn + tt_input = ttnn.from_torch( + 
input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + # ttnn forward pass + tt_output = tt_layer(tt_input) + tt_torch_output = tt2torch_tensor(tt_output) + + # Compare outputs + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.99) + logger.info(f"PCC: {pcc_message}") + + if does_pass: + logger.info("PatchMerging Passed!") + else: + logger.warning("PatchMerging Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_selection/test_swin_transformer_block.py b/models/experimental/SSR/tests/tile_selection/test_swin_transformer_block.py new file mode 100644 index 000000000000..d7024d811747 --- /dev/null +++ b/models/experimental/SSR/tests/tile_selection/test_swin_transformer_block.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.net_blocks import SwinTransformerBlock +from models.experimental.SSR.tt.tile_selection import TTSwinTransformerBlock +from models.experimental.SSR.tests.common.test_mlp import create_mlp_preprocessor +from models.experimental.SSR.tests.tile_selection.test_window_attn import create_window_attention_preprocessor +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_swin_transformer_block_preprocessor(device, weight_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + parameters = {} + + if hasattr(torch_model, "attn"): + # Preprocess attention parameters + parameters["attn"] = preprocess_model_parameters( + initialize_model=lambda: torch_model.attn, + custom_preprocessor=create_window_attention_preprocessor(device, weight_dtype), + device=device, + ) + + # Preprocess layer normalization parameters + parameters["norm1"] = {} + parameters["norm1"]["weight"] = ttnn.from_torch( + torch_model.norm1.weight, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device + ) + parameters["norm1"]["bias"] = ttnn.from_torch( + torch_model.norm1.bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device + ) + + parameters["norm2"] = {} + parameters["norm2"]["weight"] = ttnn.from_torch( + torch_model.norm2.weight, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device + ) + parameters["norm2"]["bias"] = ttnn.from_torch( + torch_model.norm2.bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT, device=device + ) + + # Preprocess MLP parameters + parameters["mlp"] = preprocess_model_parameters( + initialize_model=lambda: torch_model.mlp, + custom_preprocessor=create_mlp_preprocessor(device, weight_dtype), + device=device, + ) + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize( + "batch_size, height, width, dim, num_heads, window_size, shift_size, mlp_ratio", + ( + (3, 128, 128, 96, 3, 7, 0, 4.0), + (3, 128, 128, 96, 3, 7, 3, 4.0), + (3, 64, 64, 96, 3, 7, 0, 4.0), + (3, 64, 64, 96, 3, 7, 3, 4.0), + (3, 32, 32, 96, 3, 7, 0, 4.0), + (3, 32, 32, 96, 3, 7, 3, 4.0), + (3, 16, 16, 96, 3, 7, 0, 4.0), + (3, 16, 16, 96, 3, 7, 3, 4.0), + (3, 8, 8, 96, 3, 7, 0, 4.0), + (3, 8, 8, 96, 3, 7, 3, 4.0), + ), +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}]) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) 
+@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_swin_transformer_block( + device, batch_size, height, width, dim, num_heads, window_size, shift_size, mlp_ratio, input_dtype, weight_dtype +): + # Create input tensor + input_shape = (batch_size, height * width, dim) + x = torch.randn(input_shape) + + # Create reference model + ref_layer = SwinTransformerBlock( + dim=dim, + input_resolution=(height, width), + num_heads=num_heads, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + ) + + # Get reference output + ref_output = ref_layer(x) + + # Preprocess model parameters + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_layer, + custom_preprocessor=create_swin_transformer_block_preprocessor(device, weight_dtype), + device=device, + ) + + # Create TTNN model + tt_layer = TTSwinTransformerBlock( + parameters=parameters, + device=device, + dim=dim, + input_resolution=(height, width), + num_heads=num_heads, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + dtype=input_dtype, + ) + + # Convert input to TTNN tensor + tt_input = ttnn.from_torch(x, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + tt_input = ttnn.to_memory_config(tt_input, ttnn.L1_MEMORY_CONFIG) + + # Run forward pass + tt_output = tt_layer(tt_input) + + # Convert back to torch + tt_torch_output = tt2torch_tensor(tt_output) + + # Compare outputs + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.98) + logger.info(f"PCC: {pcc_message}") + + if does_pass: + logger.info("SwinTransformerBlock Passed!") + else: + logger.warning("SwinTransformerBlock Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tests/tile_selection/test_tile_selection.py b/models/experimental/SSR/tests/tile_selection/test_tile_selection.py new file mode 100644 index 000000000000..9b4cbbb250a3 --- /dev/null +++ b/models/experimental/SSR/tests/tile_selection/test_tile_selection.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from loguru import logger +from ttnn.model_preprocessing import preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight +from models.experimental.SSR.reference.SSR.model.tile_selection import TileSelection +from models.experimental.SSR.tt.tile_selection import TTTileSelection +from models.experimental.SSR.tests.tile_selection.test_patch_embed import create_patch_embed_preprocessor +from models.experimental.SSR.tests.tile_selection.test_basic_block import create_basic_layer_preprocessor +from models.experimental.SSR.tests.common.test_mlp import create_mlp_preprocessor +from models.experimental.SSR.tests.tile_selection.test_mask_token_inference import ( + create_mask_token_inference_preprocessor, +) +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc +from models.experimental.SSR.tests.tile_refinement.test_tile_refinement import get_precision_config + + +def create_tile_selection_preprocessor(device, dim=96, weight_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + parameters = {} + + # mask token embedding parameters + if hasattr(torch_model, "mask_token"): + parameters["mask_token"] = {} + parameters["mask_token"]["weight"] = ttnn.from_torch( + torch_model.mask_token.weight, dtype=weight_dtype, device=device, layout=ttnn.TILE_LAYOUT + ) + + # patch embedding parameters + if hasattr(torch_model, "patch_embed"): + patch_embed_params = preprocess_model_parameters( + initialize_model=lambda: torch_model.patch_embed, + custom_preprocessor=create_patch_embed_preprocessor(device), + device=device, + ) + parameters["patch_embed"] = patch_embed_params + + # encoder layers parameters + if hasattr(torch_model, "layers"): + for i, layer in enumerate(torch_model.layers): + layer_dim = int(dim * 2**i) + layer_params = preprocess_model_parameters( + initialize_model=lambda l=layer: l, + custom_preprocessor=create_basic_layer_preprocessor(device, layer_dim, weight_dtype), + device=device, + ) + parameters[f"layers.{i}"] = layer_params + + # layer norm parameters + for norm_name in ["norm1", "norm2", "norm3", "mlp_norm1", "mlp_norm2", "mlp_norm3"]: + if hasattr(torch_model, norm_name): + norm_layer = getattr(torch_model, norm_name) + parameters[norm_name] = {} + parameters[norm_name]["weight"] = ttnn.from_torch( + norm_layer.weight, dtype=weight_dtype, device=device, layout=ttnn.TILE_LAYOUT + ) + parameters[norm_name]["bias"] = ttnn.from_torch( + norm_layer.bias, dtype=weight_dtype, device=device, layout=ttnn.TILE_LAYOUT + ) + + # MLP parameters + for mlp_name in ["fea_mlp1", "fea_mlp2", "fea_mlp3", "mlp1", "mlp2", "mlp3"]: + if hasattr(torch_model, mlp_name): + mlp = getattr(torch_model, mlp_name) + mlp_params = preprocess_model_parameters( + initialize_model=lambda m=mlp: m, + custom_preprocessor=create_mlp_preprocessor(device, weight_dtype), + device=device, + ) + parameters[mlp_name] = mlp_params + + # linear classification layer parameters + for linear_name in ["linear1", "linear2", "linear3"]: + if hasattr(torch_model, linear_name): + linear_layer = getattr(torch_model, linear_name) + parameters[linear_name] = {} + parameters[linear_name]["weight"] = preprocess_linear_weight(linear_layer.weight, dtype=weight_dtype) + parameters[linear_name]["bias"] = preprocess_linear_bias(linear_layer.bias, dtype=weight_dtype) + + # mask token inference modules parameters + for mask_name in ["mask_pre1", "mask_pre2", 
"mask_pre3"]: + if hasattr(torch_model, mask_name): + mask_module = getattr(torch_model, mask_name) + mask_params = preprocess_model_parameters( + initialize_model=lambda m=mask_module: m, + custom_preprocessor=create_mask_token_inference_preprocessor(device, weight_dtype), + device=device, + ) + parameters[mask_name] = mask_params + + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize( + "image_size, patch_size, token_size, num_cls", + [ + (256, 2, 4, 1), + ], +) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}]) +@pytest.mark.parametrize( + "precision_config", + [ + lambda: get_precision_config("performance"), + lambda: get_precision_config("accuracy"), + ], + ids=["performance", "accuracy"], +) +def test_tile_selection(device, image_size, patch_size, token_size, num_cls, precision_config): + """Test TileSelection module against PyTorch reference for correctness""" + + # Get data types from precision configuration + input_dtype, weight_dtype = precision_config() + + # Create mock args object + class Args: + def __init__(self, imgsz, patchsz, token_size, dim): + self.imgsz = imgsz + self.patchsz = patchsz + self.token_size = token_size + self.dim = dim + + dim = 96 + + args = Args(image_size, patch_size, token_size, dim) + + # Create test input [B, C, H, W] + batch_size = 3 + input_tensor = torch.randn(batch_size, 3, image_size, image_size) + + # Create PyTorch reference + ref_layer = TileSelection(args, num_cls) + ref_layer.eval() + + with torch.no_grad(): + ref_output = ref_layer(input_tensor) + + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_layer, + custom_preprocessor=create_tile_selection_preprocessor(device, dim, weight_dtype), + device=device, + ) + + # Create TTNN implementation + tt_layer = TTTileSelection(device=device, parameters=parameters, args=args, num_cls=num_cls, dtype=input_dtype) + + # Convert input to TTNN + tt_input = ttnn.from_torch(input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + + # Run TTNN implementation + tt_output = tt_layer(tt_input) + + # Convert outputs back to torch for comparison + tt_mask_3 = tt2torch_tensor(tt_output) + + # Compare outputs with appropriate PCC thresholds + does_pass_3, pcc_message_3 = check_with_pcc(ref_output[0], tt_mask_3, 0.97) + logger.info(f"PCC: {pcc_message_3}") + + if does_pass_3: + logger.info("TileSelection Passed!") + else: + logger.warning("TileSelection Failed!") + + assert does_pass_3, f"PCC check failed: {pcc_message_3}" diff --git a/models/experimental/SSR/tests/tile_selection/test_window_attn.py b/models/experimental/SSR/tests/tile_selection/test_window_attn.py new file mode 100644 index 000000000000..907d165890e2 --- /dev/null +++ b/models/experimental/SSR/tests/tile_selection/test_window_attn.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest + +import ttnn + +from loguru import logger + +from models.experimental.SSR.reference.SSR.model.net_blocks import WindowAttention +from timm.models.layers import to_2tuple + +from models.experimental.SSR.tt.tile_selection import TTWindowAttention +from ttnn.model_preprocessing import preprocess_model_parameters, preprocess_linear_bias, preprocess_linear_weight +from models.utility_functions import tt2torch_tensor +from tests.ttnn.utils_for_testing import check_with_pcc + + +def create_window_attention_preprocessor(device, weight_dtype=ttnn.bfloat16): + def custom_preprocessor(torch_model, name, ttnn_module_args): + parameters = {} + if hasattr(torch_model, "qkv"): # WindowAttention model + parameters["qkv"] = {} + parameters["proj"] = {} + parameters["qkv"]["weight"] = preprocess_linear_weight(torch_model.qkv.weight, dtype=weight_dtype) + parameters["qkv"]["bias"] = preprocess_linear_bias(torch_model.qkv.bias, dtype=weight_dtype) + + # Preprocess relative position bias + relative_position_bias = torch_model.relative_position_bias_table[ + torch_model.relative_position_index.view(-1) + ].view( + torch_model.window_size[0] * torch_model.window_size[1], + torch_model.window_size[0] * torch_model.window_size[1], + -1, + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + parameters["relative_position_bias"] = ttnn.from_torch( + relative_position_bias, dtype=weight_dtype, layout=ttnn.TILE_LAYOUT + ) + + parameters["proj"]["weight"] = preprocess_linear_weight(torch_model.proj.weight, dtype=weight_dtype) + parameters["proj"]["bias"] = preprocess_linear_bias(torch_model.proj.bias, dtype=weight_dtype) + + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize( + "input_shape, window_size, num_heads, input_resolution", + ( + ((1083, 49, 96), (7, 7), 3, (128, 128)), + ((1083, 49, 96), (7, 7), 3, None), + ((300, 49, 192), (7, 7), 3, (64, 64)), + ((300, 49, 192), (7, 7), 3, None), + ((75, 49, 384), (7, 7), 3, (32, 32)), + ((75, 49, 384), (7, 7), 3, None), + ((27, 49, 768), (7, 7), 3, (16, 16)), + ((27, 49, 768), (7, 7), 3, None), + ((12, 49, 1536), (7, 7), 3, (8, 8)), + ((12, 49, 1536), (7, 7), 3, None), + ), +) +@pytest.mark.parametrize("input_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("weight_dtype", [ttnn.bfloat8_b]) +def test_window_attn(device, input_shape, window_size, num_heads, input_resolution, input_dtype, weight_dtype): + x = torch.randn(input_shape) + + qkv_bias = True + qk_scale = None + attn_drop = 0.0 + proj_drop = 0.0 + dim = input_shape[-1] + + mask_shape_map = { + (128, 128): (361, 49, 49), + (64, 64): (100, 49, 49), + (32, 32): (25, 49, 49), + (16, 16): (9, 49, 49), + (8, 8): (4, 49, 49), + } + + mask = None + if input_resolution is not None: + mask_shape = mask_shape_map[input_resolution] + mask = torch.zeros(mask_shape) + + ref_layer = WindowAttention( + dim, + window_size=to_2tuple(window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + ref_output = ref_layer(x, mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: ref_layer, + custom_preprocessor=create_window_attention_preprocessor(device, weight_dtype), + device=device, + ) + tt_layer = TTWindowAttention( + parameters=parameters, + device=device, + dim=dim, + window_size=window_size, + num_heads=num_heads, + dtype=input_dtype, + ) + tt_input = ttnn.from_torch(x, device=device, 
layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + tt_input = ttnn.to_memory_config(tt_input, ttnn.L1_MEMORY_CONFIG) + tt_mask = None + if mask is not None: + tt_mask = ttnn.from_torch(mask, device=device, layout=ttnn.TILE_LAYOUT, dtype=input_dtype) + tt_output = tt_layer(tt_input, tt_mask) + tt_torch_output = tt2torch_tensor(tt_output) + + does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.99) + logger.info(f"PCC: {pcc_message}") + + if does_pass: + logger.info("WindowAttn Passed!") + else: + logger.error("WindowAttn Failed!") + + assert does_pass, f"PCC check failed: {pcc_message}" diff --git a/models/experimental/SSR/tt/__init__.py b/models/experimental/SSR/tt/__init__.py new file mode 100644 index 000000000000..0125eb131099 --- /dev/null +++ b/models/experimental/SSR/tt/__init__.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .common import TTMlp +from .tile_selection import ( + TTWindowAttention, + TTSwinTransformerBlock, + TTPatchEmbed, + TTPatchMerging, + TTBasicLayer, + TTMaskTokenInference, +) +from .tile_refinement import ( + TTWindowAttentionTR, + TTAttenBlocks, + TTHAB, + TTCAB, + TTWindowAttentionTR, + TTChannelAttention, + TTOCAB, + TTPatchUnEmbed, + TTPatchEmbed, +) + +__all__ = [ + "TTMlp", + "TTWindowAttention", + "TTSwinTransformerBlock", + "TTPatchEmbed", + "TTPatchUnEmbed", + "TTPatchMerging", + "TTBasicLayer", + "TTMaskTokenInference", + "TTPatchUnEmbed", + "TTWindowAttentionTR", +] diff --git a/models/experimental/SSR/tt/common/__init__.py b/models/experimental/SSR/tt/common/__init__.py new file mode 100644 index 000000000000..dea7f3db2b27 --- /dev/null +++ b/models/experimental/SSR/tt/common/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .mlp import TTMlp + +__all__ = ["TTMlp"] diff --git a/models/experimental/SSR/tt/common/mlp.py b/models/experimental/SSR/tt/common/mlp.py new file mode 100644 index 000000000000..b45aeefba7b9 --- /dev/null +++ b/models/experimental/SSR/tt/common/mlp.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TTMlp(LightweightModule): + def __init__( + self, device, in_features, hidden_features=None, out_features=None, parameters=None, dtype=ttnn.bfloat16 + ): + self.device = device + + self.in_features = in_features + self.hidden_features = hidden_features + self.out_features = out_features + self.dtype = dtype + + # Initialize weights and biases based on available inputs + # Use preprocessed parameters + self.fc1_weight = parameters["fc1"]["weight"] + self.fc1_bias = parameters["fc1"]["bias"] + self.fc2_weight = parameters["fc2"]["weight"] + self.fc2_bias = parameters["fc2"]["bias"] + + def forward(self, x): + if x.memory_config().buffer_type != ttnn.BufferType.L1: + x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG) + x = ttnn.linear( + x, + self.fc1_weight, + bias=self.fc1_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(y=8, x=8), + activation="gelu", + dtype=self.dtype, + ) + + x = ttnn.linear( + x, + self.fc2_weight, + bias=self.fc2_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(y=8, x=8), + dtype=self.dtype, + ) + + return x diff --git a/models/experimental/SSR/tt/ssr.py b/models/experimental/SSR/tt/ssr.py new file mode 100644 index 000000000000..92df42b19060 --- /dev/null +++ b/models/experimental/SSR/tt/ssr.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.experimental.SSR.tt.tile_refinement.tile_refinement import TTTileRefinement +from models.experimental.SSR.tt.tile_selection.tile_selection import TTTileSelection +from models.experimental.SSR.tt.tile_refinement.upsample import TTUpsample + + +def window_partition_ttnn(x, window_size): + """TTNN implementation of window partitioning""" + b, h, w, c = x.shape + + # Reshape: (b, h, w, c) -> (b, h//ws, ws, w//ws, ws, c) + x = ttnn.reshape(x, (b, h // window_size, window_size, w // window_size, window_size, c)) + + # Permute: (0, 1, 3, 2, 4, 5) -> group windows together + x = ttnn.permute(x, (0, 1, 3, 2, 4, 5)) + + # Final reshape to get windows + x = ttnn.reshape(x, (-1, window_size, window_size, c)) + + return x + + +def window_reverse_ttnn(windows, window_size, h, w): + """TTNN implementation of window reverse""" + b = int(windows.shape[0] / (h * w / window_size / window_size)) + + # Reshape windows back to grid + windows = ttnn.reshape( + windows, + (b, h // window_size, w // window_size, window_size, window_size, -1), + ) + + # Permute back to original order + windows = ttnn.permute(windows, (0, 1, 3, 2, 4, 5)) + + # Final reshape to original spatial dimensions + windows = ttnn.reshape(windows, (b, h, w, -1)) + + return windows + + +class TTSSR(LightweightModule): + """TTNN Super-Resolution Module + + Feeds positive tiles to TR Module, negative tiles to conv layers, + then reconstructs them together. 
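+
+    The selection threshold is the 75th percentile of the TileSelection scores,
+    so roughly the top 25% of tiles are refined by the HAT-based
+    TTTileRefinement path while the remaining tiles pass through a single 3x3
+    conv_first before the shared reconstruction tail (conv_before_upsample,
+    LeakyReLU, 4x upsample, conv_last).
+
+    Minimal usage sketch (assumes `parameters` comes from
+    preprocess_model_parameters with the demo's preprocessor and `tt_input` is a
+    ttnn tensor of shape (1, 3, 256, 256) on the device):
+
+        model = TTSSR(device, parameters, args, num_cls=1,
+                      depth=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6])
+        sr, tile_scores = model(tt_input)  # sr: (1, 1024, 1024, 3) super-resolved output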
+ """ + + def __init__(self, device, parameters, args, num_cls, depth, num_heads, memory_config=None): + super().__init__() + + self.device = device + self.parameters = parameters + self.memory_config = memory_config or ttnn.DRAM_MEMORY_CONFIG + + # Initialize sub-modules using existing TTNN implementations + self.select_model = TTTileSelection( + device=device, + parameters=parameters.select_model, + args=args, + num_cls=num_cls, + memory_config=self.memory_config, + ) + + self.sr_model = TTTileRefinement( + device=device, + parameters=parameters.sr_model, + upscale=4, + img_size=64, + window_size=16, + img_range=1.0, + depths=depth, + embed_dim=180, + num_heads=num_heads, + mlp_ratio=2, + upsampler="pixelshuffle", + memory_config=self.memory_config, + ) + + # Initialize upsample module + self.upsample = TTUpsample(scale=4, num_feat=64, device=device) + + # Store conv configuration for memory optimization + self.conv_config = ttnn.Conv2dConfig( + weights_dtype=ttnn.bfloat16, + activation="", + output_layout=ttnn.TILE_LAYOUT, + deallocate_activation=True, + reallocate_halo_output=True, + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ) + self.conv_before_upsample_conv_config = ttnn.Conv2dConfig( + weights_dtype=ttnn.bfloat16, + activation="", + output_layout=ttnn.TILE_LAYOUT, + deallocate_activation=True, + reallocate_halo_output=True, + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ) + + self.compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + + def forward(self, x): + """Forward pass through SSR module""" + B, C, H, W = x.shape + + # Get tile selection features + patch_fea3 = self.select_model(x) + + # Calculate selection threshold (top 25%) + patch_fea3_flat = ttnn.reshape(patch_fea3, (-1,)) + # Convert to torch for quantile calculation + patch_fea3_torch = ttnn.to_torch(patch_fea3_flat) + threshold = torch.quantile(patch_fea3_torch.to(torch.float32), 0.75) + + # Create selection mask + pi_prime = patch_fea3_torch > threshold + pi_prime = pi_prime.view(-1) + + # Window partition the input image + x_torch = x + x_torch = ttnn.permute(x, (0, 2, 3, 1)) + patch_x_torch = window_partition_ttnn( + x_torch, + window_size=H // 4, + ) + + # Feature extraction for each patch + lr_fea_list = [] + + for i in range(B * 16): + patch_input = ttnn.unsqueeze(patch_x_torch[i], 0) # 1, 3, H/4, W/4 + if pi_prime[i] == 1: + posX, fea = self.sr_model(ttnn.permute(patch_input, (0, 3, 1, 2))) + fea = ttnn.from_device(fea) # Move to host + fea = ttnn.to_dtype(fea, ttnn.bfloat16) # Convert dtype + fea = ttnn.to_device(fea, device=self.device) # Move back to device + lr_fea_list.append(fea) + else: + # Use simple conv for negative tiles + fea = ttnn.conv2d( + input_tensor=patch_input, + weight_tensor=self.parameters.conv_first.weight, + bias_tensor=self.parameters.conv_first.bias, + in_channels=3, + out_channels=180, + device=self.device, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + batch_size=patch_input.shape[0], + input_height=patch_input.shape[1], + input_width=patch_input.shape[2], + memory_config=self.memory_config, + conv_config=self.conv_config, + ) + fea = ttnn.reshape(fea, [1, 64, 64, 180]) + fea = ttnn.from_device(fea) # Move to host + fea = ttnn.to_dtype(fea, ttnn.bfloat16) # Convert dtype + fea = ttnn.to_device(fea, device=self.device) # Move back to device + lr_fea_list.append(fea) + + # Concatenate features + lr_fea = ttnn.concat(lr_fea_list, dim=0) + + 
# Window reverse to reconstruct full feature map + lr_fea = window_reverse_ttnn( + lr_fea, + window_size=H // 4, + h=H, + w=W, + ) + + slice_config = ttnn.Conv2dSliceConfig(slice_type=ttnn.Conv2dSliceHeight, num_slices=4) + sr_fea = ttnn.conv2d( + input_tensor=lr_fea, + weight_tensor=self.parameters.conv_before_upsample.weight, + bias_tensor=self.parameters.conv_before_upsample.bias, + in_channels=180, + out_channels=64, + device=self.device, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + batch_size=lr_fea.shape[0], + input_height=lr_fea.shape[1], + input_width=lr_fea.shape[2], + dtype=ttnn.bfloat16, + return_output_dim=False, + return_weights_and_bias=False, + slice_config=slice_config, + ) + sr_fea = ttnn.reshape(sr_fea, [B, 256, 256, 64]) + + # LeakyReLU activation + sr_fea = ttnn.leaky_relu(sr_fea, negative_slope=0.01, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + # Upsample + sr_fea = self.upsample(sr_fea, self.parameters.upsample) + + # Final convolution + slice_config = ttnn.Conv2dSliceConfig(slice_type=ttnn.Conv2dSliceHeight, num_slices=4) + sr = ttnn.conv2d( + input_tensor=sr_fea, + weight_tensor=self.parameters.conv_last.weight, + bias_tensor=self.parameters.conv_last.bias, + in_channels=64, + out_channels=3, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + device=self.device, + batch_size=sr_fea.shape[0], + input_height=sr_fea.shape[1], + input_width=sr_fea.shape[2], + dtype=ttnn.bfloat16, + return_output_dim=False, + return_weights_and_bias=False, + slice_config=slice_config, + ) + + sr = ttnn.reshape(sr, [B, 1024, 1024, 3]) + + return sr, patch_fea3 + + +class TTSSR_wo_conv(LightweightModule): + def __init__(self, device, parameters, args, num_cls, depth, num_heads, memory_config=None, dtype=ttnn.bfloat16): + super().__init__() + self.device = device + self.parameters = parameters + self.memory_config = memory_config or ttnn.DRAM_MEMORY_CONFIG + self.dtype = dtype + + # Only need select_model and sr_model - no conv layers + self.select_model = TTTileSelection( + device=device, + parameters=parameters.select_model, + args=args, + num_cls=num_cls, + memory_config=self.memory_config, + dtype=dtype, + ) + + self.sr_model = TTTileRefinement( + device=device, + parameters=parameters.sr_model, + upscale=4, + img_size=64, + window_size=16, + img_range=1.0, + depths=depth, + embed_dim=180, + num_heads=num_heads, + mlp_ratio=2, + upsampler="pixelshuffle", + memory_config=self.memory_config, + dtype=dtype, + ) + + def forward(self, x): + B, C, H, W = x.shape + + # Same tile selection logic + patch_fea3 = self.select_model(x) + + # Calculate selection threshold (top 25%) + patch_fea3_flat = ttnn.reshape(patch_fea3, (-1,)) + # Convert to torch for quantile calculation + patch_fea3_torch = ttnn.to_torch(patch_fea3_flat) + threshold = torch.quantile(patch_fea3_torch.to(torch.float32), 0.75) + pi_prime = patch_fea3_torch > threshold + + # Window partition + x_torch = x + x_torch = ttnn.permute(x, (0, 2, 3, 1)) + patch_x = window_partition_ttnn(x_torch, window_size=H // 4) + + # Process each patch + sr_patches = [] + for i in range(B * 16): + patch_input = ttnn.unsqueeze(patch_x[i], 0) + + if pi_prime[i] == 1: + # Use SR model for positive tiles + posX, _ = self.sr_model(ttnn.permute(patch_input, (0, 3, 1, 2))) + sr_patches.append(posX) + else: + # Move tensor to host and convert to torch + patch_host = ttnn.to_torch(patch_input) # Shape: (1, H/4, W/4, C) + + # Convert to NCHW format for PyTorch upsample + patch_nchw = patch_host.permute(0, 3, 1, 2) # (1, H/4, W/4, C) -> (1, 
C, H/4, W/4) + + # Use PyTorch's upsample (bicubic like the reference) + negX_torch = torch.nn.functional.upsample(patch_nchw, scale_factor=4, mode="bicubic") + + # Convert back to NHWC and to TTNN tensor + negX_torch = negX_torch.permute(0, 2, 3, 1) # Back to (1, H, W, C) + negX = ttnn.from_torch(negX_torch, device=self.device, dtype=self.dtype, layout=ttnn.TILE_LAYOUT) + sr_patches.append(negX) + + ttnn.deallocate(patch_x) + + # Concatenate and reconstruct + sr = ttnn.concat(sr_patches, dim=0, memory_config=self.memory_config) + return sr, patch_fea3 diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/CAB/CAB.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/CAB/CAB.py new file mode 100644 index 000000000000..8fb61d2329c4 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/CAB/CAB.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from .channel_attention import TTChannelAttention + + +class TTCAB(LightweightModule): + def __init__( + self, device, parameters, num_feat, compress_ratio=3, squeeze_factor=30, memory_config=None, dtype=ttnn.bfloat16 + ): + super().__init__() + + self.device = device + self.memory_config = ttnn.L1_MEMORY_CONFIG + self.num_feat = num_feat + self.compress_ratio = compress_ratio + self.squeeze_factor = squeeze_factor + self.dtype = dtype + + # Extract preprocessed parameters for convolutions + self.conv1_weight = parameters["conv1"]["weight"] + self.conv1_bias = parameters["conv1"]["bias"] + self.conv2_weight = parameters["conv2"]["weight"] + self.conv2_bias = parameters["conv2"]["bias"] + + # Initialize channel attention module + self.channel_attention = TTChannelAttention( + device=device, + parameters=parameters["channel_attention"], + num_feat=num_feat, + squeeze_factor=squeeze_factor, + memory_config=memory_config, + dtype=dtype, + ) + + def forward(self, x): + # Store original input shape for convolutions + batch_size, height, width, channels = x.shape + conv_config = ttnn.Conv2dConfig( + weights_dtype=ttnn.bfloat16, + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + deallocate_activation=True, + output_layout=ttnn.TILE_LAYOUT, + activation="gelu", + ) + # First 3x3 convolution (compression) + x = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.conv1_weight, + bias_tensor=self.conv1_bias, + device=self.device, + in_channels=self.num_feat, + out_channels=self.num_feat // self.compress_ratio, + batch_size=batch_size, + input_height=height, + input_width=width, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + memory_config=self.memory_config, + conv_config=conv_config, + compute_config=ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ), + dtype=self.dtype, + ) + + # Reshape from flattened conv output back to spatial format + x = ttnn.reshape(x, [batch_size, height, width, self.num_feat // self.compress_ratio]) + + conv_config = ttnn.Conv2dConfig( + weights_dtype=ttnn.bfloat16, + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + deallocate_activation=True, + output_layout=ttnn.TILE_LAYOUT, + ) + # Second 3x3 convolution (expansion) + x = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.conv2_weight, + bias_tensor=self.conv2_bias, + device=self.device, + in_channels=self.num_feat // self.compress_ratio, + 
out_channels=self.num_feat, + batch_size=batch_size, + input_height=height, + input_width=width, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + memory_config=self.memory_config, + compute_config=ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ), + conv_config=conv_config, + dtype=self.dtype, + ) + + # Reshape from flattened conv output back to spatial format + x = ttnn.reshape(x, [batch_size, height, width, self.num_feat], memory_config=self.memory_config) + + # Apply channel attention + x = self.channel_attention(x) + + return x diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/CAB/__init__.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/CAB/__init__.py new file mode 100644 index 000000000000..04d04f23362e --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/CAB/__init__.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .CAB import TTCAB +from .channel_attention import TTChannelAttention + +__all__ = [ + "TTCAB", + "TTChannelAttention", +] diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/CAB/channel_attention.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/CAB/channel_attention.py new file mode 100644 index 000000000000..8eb72bb4affd --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/CAB/channel_attention.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TTChannelAttention(LightweightModule): + def __init__(self, device, parameters, num_feat, squeeze_factor=16, memory_config=None, dtype=ttnn.bfloat16): + super().__init__() + + self.device = device + self.memory_config = ttnn.L1_MEMORY_CONFIG + self.num_feat = num_feat + self.squeeze_factor = squeeze_factor + self.dtype = dtype + + # Extract preprocessed parameters + self.conv1_weight = parameters["conv1"]["weight"] + self.conv1_bias = parameters["conv1"]["bias"] + self.conv2_weight = parameters["conv2"]["weight"] + self.conv2_bias = parameters["conv2"]["bias"] + + def forward(self, x): + original_x = x + original_shape = x.shape + + if x.memory_config().buffer_type != ttnn.BufferType.L1: + x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG, dtype=self.dtype) + x = ttnn.global_avg_pool2d(x, memory_config=self.memory_config, dtype=self.dtype) + + if original_shape[-1] == 180: + x = ttnn.slice( + x, + starts=[0, 0, 0, 0], # Start indices for each dimension + ends=[original_shape[0], 1, 1, 180], # End indices - slice to 180 in last dim + steps=[1, 1, 1, 1], # Step size for each dimension + ) + + x = ttnn.linear( + x, + self.conv1_weight, + bias=self.conv1_bias, + memory_config=self.memory_config, + activation="relu", + dtype=self.dtype, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + + x = ttnn.linear( + x, + self.conv2_weight, + bias=self.conv2_bias, + memory_config=self.memory_config, + dtype=self.dtype, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + + # Sigmoid activation + x = ttnn.sigmoid(x) + + batch_size, height, width, channels = original_shape + attention_weights = ttnn.reshape(x, [batch_size, 1, 1, channels]) + attention_weights = ttnn.repeat(attention_weights, [1, height, width, 1], memory_config=self.memory_config) + + # Element-wise multiplication with original 
input + output = ttnn.multiply(original_x, attention_weights, memory_config=self.memory_config, dtype=self.dtype) + + return output diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/HAB.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/HAB.py new file mode 100644 index 000000000000..66084b2d0eb7 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/HAB.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + +from .CAB import TTCAB +from models.experimental.SSR.tt.common.mlp import TTMlp +from .window_attn_tr import TTWindowAttentionTR + + +class TTHAB(LightweightModule): + def __init__( + self, + device, + parameters, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + memory_config=None, + dtype=ttnn.bfloat16, + ): + super().__init__() + self.device = device + self.memory_config = memory_config or ttnn.L1_MEMORY_CONFIG + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.dtype = dtype + + if min(self.input_resolution) <= self.window_size: + self.shift_size = 0 + self.window_size = min(self.input_resolution) + + # Extract preprocessed parameters + self.norm1_weight = parameters["norm1"]["weight"] + self.norm1_bias = parameters["norm1"]["bias"] + self.norm2_weight = parameters["norm2"]["weight"] + self.norm2_bias = parameters["norm2"]["bias"] + self.conv_scale = parameters.get("conv_scale", 0.01) + + # Initialize sub-modules + self.attn = TTWindowAttentionTR( + device=device, + parameters=parameters["attn"], + dim=dim, + window_size=(self.window_size, self.window_size), + num_heads=num_heads, + memory_config=memory_config, + dtype=dtype, + ) + + self.conv_block = TTCAB( + device=device, parameters=parameters["conv_block"], num_feat=dim, memory_config=memory_config, dtype=dtype + ) + + self.mlp = TTMlp( + device=device, + in_features=dim, + hidden_features=int(dim * mlp_ratio), + parameters=parameters["mlp"], + dtype=dtype, + ) + + def forward(self, x, x_size, rpi_sa, attn_mask): + h, w = x_size + b, seq_len, c = x.shape + if x.memory_config().buffer_type != ttnn.BufferType.L1: + x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG, dtype=self.dtype) + shortcut = x + + # Layer norm 1 + x = ttnn.layer_norm(x, weight=self.norm1_weight, bias=self.norm1_bias) + + # Reshape to spatial format for conv and attention + x = ttnn.reshape(x, [b, h, w, c]) + + # Convolutional branch + conv_x = self.conv_block(x) + conv_x = ttnn.reshape(conv_x, [b, h * w, c]) + conv_x = ttnn.multiply(conv_x, self.conv_scale, dtype=self.dtype) + + # Attention branch - handle cyclic shifttt-metal + if self.shift_size > 0: + # Cyclic shift + shifted_x = ttnn.roll(x, [-self.shift_size, -self.shift_size], [1, 2]) + current_attn_mask = attn_mask + else: + shifted_x = x + current_attn_mask = None + + # Window partition + if shifted_x.memory_config().buffer_type != ttnn.BufferType.L1: + shifted_x = ttnn.to_memory_config(shifted_x, self.memory_config, dtype=self.dtype) + x_windows = self._window_partition(shifted_x, self.window_size) + x_windows = ttnn.reshape(x_windows, [-1, self.window_size * self.window_size, c]) + + # Window attention + attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=current_attn_mask) + + # Window reverse + attn_windows = 
ttnn.reshape(attn_windows, [-1, self.window_size, self.window_size, c]) + shifted_x = self._window_reverse(attn_windows, self.window_size, h, w) + + # Reverse cyclic shift + if self.shift_size > 0: + attn_x = ttnn.roll(shifted_x, [self.shift_size, self.shift_size], [1, 2]) + else: + attn_x = shifted_x + + if attn_x.memory_config().buffer_type != ttnn.BufferType.L1: + attn_x = ttnn.to_memory_config(attn_x, ttnn.L1_MEMORY_CONFIG, dtype=self.dtype) + attn_x = ttnn.reshape(attn_x, [b, h * w, c]) + + # First residual connection + x = ttnn.add(shortcut, attn_x, dtype=self.dtype) + x = ttnn.add(x, conv_x, dtype=self.dtype) + + # MLP branch + x_norm = ttnn.layer_norm(x, weight=self.norm2_weight, bias=self.norm2_bias) + mlp_out = self.mlp(x_norm) + + # Second residual connection + x = ttnn.add(x, mlp_out) + + return x + + def _window_partition(self, x, window_size): + """Partition into non-overlapping windows""" + B, H, W, C = x.shape + num_windows = (H // window_size) * (W // window_size) + return ttnn.reshape(x, [B * num_windows, window_size, window_size, C], memory_config=self.memory_config) + + def _window_reverse(self, windows, window_size, H, W): + B = windows.shape[0] // (H * W // window_size // window_size) + return ttnn.reshape(windows, [B, H, W, -1], memory_config=self.memory_config) diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/__init__.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/__init__.py new file mode 100644 index 000000000000..3aec76c5613f --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/__init__.py @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .HAB import TTHAB +from .window_attn_tr import TTWindowAttentionTR +from .CAB import TTCAB, TTChannelAttention + +__all__ = [ + "TTHAB", + "TTCAB", + "TTWindowAttentionTR", + "TTChannelAttention", +] diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/window_attn_tr.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/window_attn_tr.py new file mode 100644 index 000000000000..9e58200d514f --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/HAB/window_attn_tr.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +from models.common.lightweightmodule import LightweightModule + + +class TTWindowAttentionTR(LightweightModule): + def __init__(self, device, parameters, dim, window_size, num_heads, memory_config=None, dtype=ttnn.bfloat16): + super().__init__() + self.device = device + self.memory_config = memory_config or ttnn.L1_MEMORY_CONFIG + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.dtype = dtype + + # Extract preprocessed parameters + self.qkv_weight = parameters["qkv"]["weight"] + self.qkv_bias = parameters["qkv"]["bias"] if "bias" in parameters["qkv"] else None + self.proj_weight = parameters["proj"]["weight"] + self.proj_bias = parameters["proj"]["bias"] if "bias" in parameters["proj"] else None + self.relative_position_bias = parameters["relative_position_bias"] + + # Scale factor + self.scale = self.head_dim**-0.5 + + def forward(self, x, rpi, mask=None): + b_, n, c = x.shape + if x.memory_config().buffer_type != ttnn.BufferType.L1: + x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG) + self.memory_config = ttnn.L1_MEMORY_CONFIG if b_ * n * c < 1_100_000 else ttnn.DRAM_MEMORY_CONFIG + qkv = ttnn.linear( + x, + self.qkv_weight, + bias=self.qkv_bias, + memory_config=self.memory_config, + dtype=self.dtype, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + ttnn.deallocate(x) + ( + q, + k, + v, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + qkv, memory_config=ttnn.L1_MEMORY_CONFIG, num_heads=self.num_heads + ) + + ttnn.deallocate(qkv) + + # Remove the first dimension + q = ttnn.squeeze(q, 0) + k = ttnn.squeeze(k, 0) + v = ttnn.squeeze(v, 0) + + # Scale Q + q = ttnn.multiply(q, self.scale, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.dtype) + + attn = ttnn.matmul( + q, + k, + compute_kernel_config=ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + ), + memory_config=self.memory_config, + core_grid=ttnn.CoreGrid(y=8, x=8), + dtype=self.dtype, + ) + ttnn.deallocate(q) + ttnn.deallocate(k) + + attn = ttnn.add(attn, self.relative_position_bias, memory_config=self.memory_config, dtype=self.dtype) + + # Apply mask if provided + if mask is not None: + nw = mask.shape[0] + attn = ttnn.reshape(attn, [b_ // nw, nw, self.num_heads, n, n]) + mask_expanded = ttnn.unsqueeze(ttnn.unsqueeze(mask, 1), 0) + attn = ttnn.add(attn, mask_expanded, dtype=self.dtype) + attn = ttnn.reshape(attn, [-1, self.num_heads, n, n]) + + # Softmax + attn = ttnn.softmax(attn, dim=-1, memory_config=self.memory_config) + + # Apply attention to values + x = ttnn.matmul( + attn, + v, + compute_kernel_config=ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + ), + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(y=8, x=8), + dtype=self.dtype, + ) # [b_, num_heads, n, head_dim] + + output_tensor = ttnn.transformer.concatenate_heads(x) + + if self.proj_weight.shape[-1] != output_tensor.shape[-1]: + head_size = self.proj_weight.shape[-1] // self.num_heads + padded_head_size = output_tensor.shape[-1] // self.num_heads + output_tensor = ttnn.to_torch(output_tensor) + + # Remove the padding + output_tensor = torch.cat( + [chunk[..., :head_size] for chunk in torch.split(output_tensor, padded_head_size, dim=-1)], dim=-1 + ) + x = ttnn.from_torch( + output_tensor, + device=self.device, + dtype=self.dtype, + memory_config=self.memory_config, + layout=ttnn.TILE_LAYOUT, + ) + + x = ttnn.linear( + x, + self.proj_weight, + 
bias=self.proj_bias, + memory_config=self.memory_config, + core_grid=ttnn.CoreGrid(y=8, x=8), + dtype=self.dtype, + ) + + return x diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/OCAB/OCAB.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/OCAB/OCAB.py new file mode 100644 index 000000000000..f4f6eaba6a48 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/OCAB/OCAB.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TTOCAB(LightweightModule): + def __init__( + self, + device, + dim, + input_resolution, + window_size, + overlap_ratio, + num_heads, + parameters, + qkv_bias=True, + qk_scale=None, + mlp_ratio=2, + dtype=ttnn.bfloat16, + depth=[6, 6, 6, 6, 6, 6], + ): + super().__init__() + + self.device = device + self.dim = dim + self.input_resolution = input_resolution + self.window_size = window_size + self.num_heads = num_heads + self.overlap_ratio = overlap_ratio + self.dtype = dtype + self.depth = depth + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.overlap_win_size = int(window_size * overlap_ratio) + window_size + + # Extract preprocessed parameters + self.norm1_weight = parameters["norm1"]["weight"] + self.norm1_bias = parameters["norm1"]["bias"] + + self.qkv_weight = parameters["qkv"]["weight"] + self.qkv_bias = parameters["qkv"]["bias"] + + self.relative_position_bias_table = parameters["relative_position_bias_table"] + + self.proj_weight = parameters["proj"]["weight"] + self.proj_bias = parameters["proj"]["bias"] + + self.norm2_weight = parameters["norm2"]["weight"] + self.norm2_bias = parameters["norm2"]["bias"] + + self.mlp_fc1_weight = parameters["mlp"]["fc1"]["weight"] + self.mlp_fc1_bias = parameters["mlp"]["fc1"]["bias"] + self.mlp_fc2_weight = parameters["mlp"]["fc2"]["weight"] + self.mlp_fc2_bias = parameters["mlp"]["fc2"]["bias"] + + def forward(self, x, x_size, rpi): + h, w = x_size + b, _, c = x.shape + shortcut = x + + # Layer normalization + x = ttnn.layer_norm(x, weight=self.norm1_weight, bias=self.norm1_bias, memory_config=ttnn.L1_MEMORY_CONFIG) + + # Fused QKV projection - use single linear operation + qkv = ttnn.linear( + x, + self.qkv_weight, + bias=self.qkv_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=self.dtype, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + # Use transformer function for QKV splitting + query, key, value = ttnn.transformer.split_query_key_value_and_split_heads( + qkv, memory_config=ttnn.DRAM_MEMORY_CONFIG, num_heads=self.num_heads, transpose_key=False + ) + ttnn.deallocate(qkv) + + sdpa_program_config = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=[8, 7], + q_chunk_size=512, + k_chunk_size=512, + exp_approx_mode=False, + ) + compute_kernel_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + + # Use optimized scaled dot product attention + attention_output = ttnn.transformer.scaled_dot_product_attention( + query, + key, + value, + is_causal=False, + scale=self.scale, + program_config=sdpa_program_config if self.depth == [6, 6, 6, 6, 6, 6] else None, + compute_kernel_config=compute_kernel_config if self.depth == [6, 6, 6, 6, 6, 6] else None, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + # Deallocate intermediate tensors + ttnn.deallocate(query) + 
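+        # Note: the concatenated heads can come back with a padded channel
+        # dimension; when context_layer.shape[-1] != self.dim the extra
+        # channels are sliced off on host below, before the output projection.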
ttnn.deallocate(key) + ttnn.deallocate(value) + # Use transformer function for head concatenation + context_layer = ttnn.transformer.concatenate_heads( + attention_output, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + ttnn.deallocate(attention_output) + + if context_layer.shape[-1] != self.dim: + # remove padding + context_layer = ttnn.to_torch(context_layer)[..., : self.dim] # slice to 180 and remove padding + context_layer = ttnn.from_torch( + context_layer, + device=self.device, + dtype=self.dtype, + memory_config=ttnn.L1_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + x = ttnn.reshape(context_layer, (b, h * w, self.dim), memory_config=ttnn.L1_MEMORY_CONFIG) + + # Output projection and residual + x = ttnn.linear( + x, + self.proj_weight, + bias=self.proj_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(y=8, x=8), + dtype=self.dtype, + ) + x = ttnn.add(x, shortcut, dtype=self.dtype) + + x = ttnn.layer_norm(x, weight=self.norm2_weight, bias=self.norm2_bias, memory_config=ttnn.L1_MEMORY_CONFIG) + + mlp_out = ttnn.linear( + x, + self.mlp_fc1_weight, + bias=self.mlp_fc1_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=self.dtype, + core_grid=ttnn.CoreGrid(y=8, x=8), + activation="gelu", + ) + mlp_out = ttnn.linear( + mlp_out, + self.mlp_fc2_weight, + bias=self.mlp_fc2_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(y=8, x=8), + dtype=self.dtype, + ) + + x = ttnn.add(x, mlp_out, dtype=self.dtype) + return x diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/OCAB/__init__.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/OCAB/__init__.py new file mode 100644 index 000000000000..41133dcaec87 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/OCAB/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .OCAB import TTOCAB + +__all__ = [ + "TTOCAB", +] diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/__init__.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/__init__.py new file mode 100644 index 000000000000..d27c9dd53728 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .atten_blocks import TTAttenBlocks +from .HAB import TTHAB, TTCAB, TTWindowAttentionTR, TTChannelAttention +from .OCAB import TTOCAB + +__all__ = [ + "TTAttenBlocks", + "TTHAB", + "TTCAB", + "TTWindowAttentionTR", + "TTChannelAttention", + "TTOCAB", +] diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/atten_blocks.py b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/atten_blocks.py new file mode 100644 index 000000000000..a107932a4ade --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/ATTEN_BLK/atten_blocks.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from .HAB import TTHAB +from .OCAB import TTOCAB + + +class TTAttenBlocks(LightweightModule): + def __init__( + self, + device, + parameters, + dim, + input_resolution, + depth, + num_heads, + window_size, + compress_ratio, + squeeze_factor, + conv_scale, + overlap_ratio, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + downsample=None, + memory_config=None, + dtype=ttnn.bfloat16, + ): + super().__init__() + self.device = device + self.memory_config = memory_config or ttnn.DRAM_MEMORY_CONFIG + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.dtype = dtype + + # Build HAB blocks + self.blocks = [] + for i in range(depth): + # Calculate shift size: 0 for even indices, window_size // 2 for odd indices + shift_size = 0 if (i % 2 == 0) else window_size // 2 + + hab_block = TTHAB( + device=device, + parameters=parameters["blocks"][i], + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + memory_config=memory_config, + dtype=dtype, + ) + self.blocks.append(hab_block) + + # OCAB (Overlapping Cross Attention Block) + self.overlap_attn = TTOCAB( + device=device, + dim=dim, + input_resolution=input_resolution, + window_size=window_size, + overlap_ratio=overlap_ratio, + num_heads=num_heads, + parameters=parameters["overlap_attn"], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + mlp_ratio=mlp_ratio, + dtype=dtype, + depth=depth, + ) + + self.downsample = None + if downsample is not None: + self.downsample = downsample + + def forward(self, x, x_size, params): + """ + Forward pass through all attention blocks + + Args: + x: Input tensor of shape [batch, seq_len, dim] + x_size: Tuple of (height, width) for spatial dimensions + params: Dictionary containing: + - "rpi_sa": Relative position index for self-attention + - "attn_mask": Attention mask for shifted windows + - "rpi_oca": Relative position index for overlapping cross attention + """ + # Process through all HAB blocks + for i, blk in enumerate(self.blocks): + x = blk(x, x_size, params["rpi_sa"], params["attn_mask"]) + + # Apply overlapping cross attention + x = self.overlap_attn(x, x_size, params["rpi_oca"]) + + # Apply downsampling if present + if self.downsample is not None: + x = self.downsample(x) + + return x diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/RHAG.py b/models/experimental/SSR/tt/tile_refinement/RHAG/RHAG.py new file mode 100644 index 000000000000..1df1263879f8 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/RHAG.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from .ATTEN_BLK import TTAttenBlocks + +# from models.experimental.SSR.tt.patch_embed import TTPatchEmbed +from .patch_embed_tile_refinement import TTPatchEmbed +from .patch_unembed import TTPatchUnEmbed + + +class TTRHAG(LightweightModule): + """TTNN Residual Hybrid Attention Group (RHAG). 
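+
+    Runs `depth` HAB blocks followed by a single OCAB, converts the sequence
+    back to spatial form with patch_unembed, applies the residual-connection
+    conv (or identity), re-embeds the patches, and adds the block input back
+    as a residual.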
+ + Args: + device: TTNN device + parameters: Preprocessed parameters dictionary + dim (int): Number of input channels + input_resolution (tuple[int]): Input resolution + depth (int): Number of blocks + num_heads (int): Number of attention heads + window_size (int): Local window size + compress_ratio (int): Compression ratio for CAB + squeeze_factor (int): Squeeze factor for channel attention + conv_scale (float): Scale factor for conv branch + overlap_ratio (float): Overlap ratio for OCAB + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim + qkv_bias (bool): If True, add a learnable bias to query, key, value + qk_scale (float | None): Override default qk scale + drop (float): Dropout rate + attn_drop (float): Attention dropout rate + drop_path (float | tuple[float]): Stochastic depth rate + downsample: Downsample layer at the end of the layer + img_size (int): Input image size + patch_size (int): Patch size + resi_connection (str): The convolutional block before residual connection + memory_config: TTNN memory configuration + """ + + def __init__( + self, + device, + parameters, + dim, + input_resolution, + depth, + num_heads, + window_size, + compress_ratio, + squeeze_factor, + conv_scale, + overlap_ratio, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + downsample=None, + img_size=224, + patch_size=4, + resi_connection="1conv", + memory_config=None, + dtype=ttnn.bfloat16, + ): + super().__init__() + + self.device = device + self.memory_config = memory_config or ttnn.DRAM_MEMORY_CONFIG + self.dim = dim + self.input_resolution = input_resolution + self.resi_connection = resi_connection + self.dtype = dtype + + # Initialize AttenBlocks (residual_group) + self.residual_group = TTAttenBlocks( + device=device, + parameters=parameters["residual_group"], + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + compress_ratio=compress_ratio, + squeeze_factor=squeeze_factor, + conv_scale=conv_scale, + overlap_ratio=overlap_ratio, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path, + downsample=downsample, + memory_config=memory_config, + dtype=dtype, + ) + + # Initialize convolutional layer for residual connection + if resi_connection == "1conv": + # Extract conv parameters + self.conv_weight = parameters["conv"]["weight"] + self.conv_bias = parameters["conv"]["bias"] + + # Conv2d configuration + self.conv_config = ttnn.Conv2dConfig( + weights_dtype=ttnn.bfloat16, + activation="", + output_layout=ttnn.TILE_LAYOUT, + deallocate_activation=True, + reallocate_halo_output=True, + ) + + # Compute configuration + self.compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + elif resi_connection == "identity": + # Identity connection - no conv layer needed + self.conv_weight = None + self.conv_bias = None + + # Initialize PatchEmbed + self.patch_embed = TTPatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=0, # Set to 0 as in original + embed_dim=dim, + norm_layer=None, + device=device, + parameters=parameters["patch_embed"], + memory_config=memory_config, + ) + + # Initialize PatchUnEmbed + self.patch_unembed = TTPatchUnEmbed( + mesh_device=device, + img_size=img_size, + patch_size=patch_size, + in_chans=0, # Set to 0 as in original + embed_dim=dim, + ) + + def forward(self, x, x_size, 
params): + """ + Forward pass through RHAG + + Args: + x: Input tensor of shape [batch, seq_len, dim] + x_size: Tuple of (height, width) for spatial dimensions + params: Dictionary containing: + - "rpi_sa": Relative position index for self-attention + - "attn_mask": Attention mask for shifted windows + - "rpi_oca": Relative position index for overlapping cross attention + + Returns: + Output tensor with residual connection + """ + # Store input for residual connection + shortcut = x + + # Pass through residual group (AttenBlocks) + x = self.residual_group(x, x_size, params) + + # Patch unembed: convert from sequence to spatial format + x = self.patch_unembed(x, x_size) + + # Apply convolutional layer + if self.resi_connection == "1conv": + batch_size, embed_dim, height, width = x.shape + x = ttnn.permute(x, (0, 2, 3, 1)) # (batch_size, embed_dim, num_patches) + + # Apply 3x3 convolution with padding=1 + x, [out_height, out_width] = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.conv_weight, + bias_tensor=self.conv_bias, + in_channels=self.dim, + out_channels=self.dim, + device=self.device, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + batch_size=batch_size, + input_height=height, + input_width=width, + conv_config=self.conv_config, + compute_config=self.compute_config, + dtype=self.dtype, + return_output_dim=True, + ) + + x = ttnn.reshape(x, (batch_size, out_height, out_width, self.dim)) + elif self.resi_connection == "identity": + x = ttnn.permute(x, (0, 2, 3, 1)) # (batch_size, embed_dim, num_patches) + + # Patch embed: convert back to sequence format + x = self.patch_embed(x) + + x = ttnn.reshape(x, (x.shape[0], self.input_resolution[0] * self.input_resolution[1], self.dim)) + + # Add residual connection + x = ttnn.add(x, shortcut, dtype=self.dtype) + + return x diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/__init__.py b/models/experimental/SSR/tt/tile_refinement/RHAG/__init__.py new file mode 100644 index 000000000000..3caf5c9d0454 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .RHAG import TTRHAG +from .ATTEN_BLK import TTAttenBlocks, TTHAB, TTCAB, TTWindowAttentionTR, TTChannelAttention, TTOCAB +from .patch_embed_tile_refinement import TTPatchEmbed +from .patch_unembed import TTPatchUnEmbed + +__all__ = [ + "TTRHAG", + "TTAttenBlocks", + "TTHAB", + "TTCAB", + "TTWindowAttentionTR", + "TTChannelAttention", + "TTOCAB", + "TTPatchEmbed", + "TTPatchUnEmbed", +] diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/patch_embed_tile_refinement.py b/models/experimental/SSR/tt/tile_refinement/RHAG/patch_embed_tile_refinement.py new file mode 100644 index 000000000000..9882746a1f7f --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/patch_embed_tile_refinement.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TTPatchEmbed(LightweightModule): + """TTNN Image to Patch Embedding (simplified version) + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer: Normalization layer. 
Default: None + device: TTNN device + parameters: Preprocessed parameters dictionary + memory_config: TTNN memory configuration + """ + + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None, + device=None, + parameters=None, + memory_config=None, + ): + super().__init__() + + # Convert to tuples (assuming square images/patches) + self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size + self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size + + self.patches_resolution = [self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1]] + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + + self.memory_config = memory_config or ttnn.DRAM_MEMORY_CONFIG + + # Store normalization parameters if provided + if norm_layer is not None and parameters is not None: + self.norm_weight = parameters.get("norm", {}).get("weight") + self.norm_bias = parameters.get("norm", {}).get("bias") + else: + self.norm_weight = None + self.norm_bias = None + + def forward(self, x): + """ + Forward pass through patch embedding + + Args: + x: Input tensor of shape [batch, channels, height, width] + + Returns: + Output tensor of shape [batch, num_patches, embed_dim] + """ + + if x.is_sharded(): + x = ttnn.to_memory_config(x, ttnn.DRAM_MEMORY_CONFIG) + + # Apply normalization if available + if self.norm_weight is not None: + x = ttnn.permute(x, [0, 2, 3, 1]) + x = ttnn.layer_norm(x, weight=self.norm_weight, bias=self.norm_bias, memory_config=self.memory_config) + + return x diff --git a/models/experimental/SSR/tt/tile_refinement/RHAG/patch_unembed.py b/models/experimental/SSR/tt/tile_refinement/RHAG/patch_unembed.py new file mode 100644 index 000000000000..18f665ae62f3 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/RHAG/patch_unembed.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TTPatchUnEmbed(LightweightModule): + """Image to Patch Unembedding in TTNN""" + + def __init__(self, mesh_device, img_size=224, patch_size=4, in_chans=3, embed_dim=96): + super().__init__() + + self.embed_dim = embed_dim + + def forward(self, x, x_size): + batch_size = x.shape[0] + + # Transpose from (B, N, C) to (B, C, N) equivalent + x = ttnn.permute(x, (0, 2, 1)) # (batch_size, embed_dim, num_patches) + + # Reshape to spatial dimensions + x = ttnn.reshape(x, (batch_size, self.embed_dim, x_size[0], x_size[1]), memory_config=ttnn.L1_MEMORY_CONFIG) + + return x diff --git a/models/experimental/SSR/tt/tile_refinement/__init__.py b/models/experimental/SSR/tt/tile_refinement/__init__.py new file mode 100644 index 000000000000..c1e88f286d99 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/__init__.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +from .tile_refinement import TTTileRefinement +from .upsample import TTUpsample +from .RHAG import ( + TTRHAG, + TTAttenBlocks, + TTHAB, + TTCAB, + TTWindowAttentionTR, + TTChannelAttention, + TTOCAB, + TTPatchEmbed, + TTPatchUnEmbed, +) + +__all__ = [ + "TTTileRefinement", + "TTPatchEmbedTR", + "TTPatchUnEmbed", + "TTUpsample", + "TTRHAG", + "TTAttenBlocks", + "TTHAB", + "TTCAB", + "TTWindowAttentionTR", + "TTChannelAttention", + "TTOCAB", +] diff --git a/models/experimental/SSR/tt/tile_refinement/tile_refinement.py b/models/experimental/SSR/tt/tile_refinement/tile_refinement.py new file mode 100644 index 000000000000..3b57ea0f1190 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/tile_refinement.py @@ -0,0 +1,401 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +from models.common.lightweightmodule import LightweightModule +from .RHAG import TTPatchEmbed, TTPatchUnEmbed, TTRHAG +from .upsample import TTUpsample + + +class TTHAT(LightweightModule): + """TTNN Hybrid Attention Transformer base class""" + + def __init__( + self, + device, + parameters, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=(6, 6, 6, 6), + num_heads=(6, 6, 6, 6), + window_size=7, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=0.5, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + ape=False, + patch_norm=True, + upscale=2, + img_range=1.0, + upsampler="", + resi_connection="1conv", + memory_config=None, + h=64, + w=64, + dtype=ttnn.bfloat16, + **kwargs, + ): + super().__init__() + + self.device = device + self.parameters = parameters + self.window_size = window_size + self.shift_size = window_size // 2 + self.overlap_ratio = overlap_ratio + self.img_range = img_range + self.upscale = upscale + self.upsampler = upsampler + self.embed_dim = embed_dim + self.num_layers = len(depths) + self.memory_config = ttnn.DRAM_MEMORY_CONFIG + self.h = h + self.w = w + self.ape = ape + self.layers = [] + self.dtype = dtype + num_feat = 64 + + self.patch_embed = TTPatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=0, # Set to 0 as in original + embed_dim=180, + norm_layer=1, + device=device, + parameters=parameters["patch_embed"], + memory_config=memory_config, + ) + + for i_layer in range(self.num_layers): + layer = TTRHAG( + device=device, + parameters=self.parameters[f"layers.{i_layer}"], + dim=embed_dim, + input_resolution=(64, 64), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + overlap_ratio=overlap_ratio, + mlp_ratio=mlp_ratio, + img_size=max(64, 64), + patch_size=4, + resi_connection=resi_connection, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=dtype, + ) + self.layers.append(layer) + + self.patch_unembed = TTPatchUnEmbed( + mesh_device=device, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim + ) + + self.upsample = TTUpsample(upscale, num_feat, device, dtype=dtype) + + # Mean normalization values + if in_chans == 3: + rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1) + self.mean = ttnn.from_torch(rgb_mean, dtype=dtype, device=device) + else: + self.mean = ttnn.zeros((1, 1, 1, 1), dtype=dtype, device=device) + + def calculate_mask(self, x_size): + """Calculate attention mask for SW-MSA with proper padding""" + h, w = 
x_size + + # Calculate padding needed to make dimensions divisible by window_size + pad_h = (self.window_size - h % self.window_size) % self.window_size + pad_w = (self.window_size - w % self.window_size) % self.window_size + + # Use padded dimensions + padded_h = h + pad_h + padded_w = w + pad_w + + # Create mask with padded dimensions + img_mask = torch.zeros((1, padded_h, padded_w, 1)) + + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + + cnt = 0 + for h_slice in h_slices: + for w_slice in w_slices: + img_mask[:, h_slice, w_slice, :] = cnt + cnt += 1 + + # Convert to TTNN tensor + return ttnn.from_torch(img_mask, dtype=ttnn.bfloat16, device=self.device) + + def forward_features(self, x): + """Forward pass through transformer layers""" + x_size = (self.h, self.w) + + # Patch embedding + x = self.patch_embed(x) + + # Add absolute position embedding if enabled + if self.ape and hasattr(self.parameters, "absolute_pos_embed"): + x = ttnn.add(x, self.parameters.absolute_pos_embed, memory_config=self.memory_config, dtype=self.dtype) + # Apply transformer layers + x = ttnn.reshape(x, [x.shape[0], x.shape[1] * x.shape[2], x.shape[3]]) + for i in range(self.num_layers): + x = self.layers[i](x, x_size, self.parameters["forward_params"]) + ttnn.ReadDeviceProfiler(self.device) + + # Layer normalization + x = ttnn.layer_norm( + x, + weight=self.parameters.norm.weight, + bias=self.parameters.norm.bias, + memory_config=self.memory_config, + ) + + # Patch unembedding + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + """Main forward pass""" + # Normalize input + x = ttnn.subtract(x, self.mean, memory_config=self.memory_config) + x = ttnn.multiply(x, self.img_range, memory_config=self.memory_config) + + if self.upsampler == "pixelshuffle": + # Shallow feature extraction + x = ttnn.conv2d( + x, + self.parameters.conv_first.weight, + bias=self.parameters.conv_first.bias, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + device=self.device, + memory_config=self.memory_config, + ) + + # Deep feature extraction with residual connection + features = self.forward_features(x) + # Residual connection after body + if hasattr(self.parameters, "conv_after_body"): + features = ttnn.conv2d( + features, + self.parameters.conv_after_body.weight, + bias=self.parameters.conv_after_body.bias, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + device=self.device, + memory_config=self.memory_config, + ) + + x = ttnn.add(x, features, memory_config=self.memory_config) + + # Pre-upsample convolution + x = ttnn.conv2d( + x, + self.parameters.conv_before_upsample.weight, + bias=self.parameters.conv_before_upsample.bias, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + device=self.device, + memory_config=self.memory_config, + ) + + # LeakyReLU activation + x = ttnn.leaky_relu(x, negative_slope=0.01) + + # Upsampling + x = self.parameters.upsample(x) + + # Final convolution + x = ttnn.conv2d( + x, + self.parameters.conv_last.weight, + bias=self.parameters.conv_last.bias, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + device=self.device, + memory_config=self.memory_config, + ) + + # Denormalize output + x = ttnn.divide(x, self.img_range, memory_config=self.memory_config) + x = ttnn.add(x, self.mean, memory_config=self.memory_config) + + return x + + +class 
TTTileRefinement(TTHAT): + """TTNN Tile Refinement Module + + Outputs both feature and final upsampled image, following the same pattern + as the PyTorch TileRefinement class. + """ + + def forward(self, x): + """Forward pass that returns both output and features""" + # Normalize input + batch_size = x.shape[0] + self.mean = ttnn.to_layout(self.mean, ttnn.TILE_LAYOUT) + x = ttnn.subtract(x, self.mean, memory_config=self.memory_config, dtype=self.dtype) + x = ttnn.multiply(x, self.img_range, memory_config=self.memory_config, dtype=self.dtype) + + if self.upsampler == "pixelshuffle": + # Shallow feature extraction + x = ttnn.permute(x, (0, 2, 3, 1)) # (batch_size, embed_dim, num_patches) + self.conv_config = ttnn.Conv2dConfig( + weights_dtype=ttnn.bfloat16, + activation="", + output_layout=ttnn.TILE_LAYOUT, + deallocate_activation=True, # Free input memory after use + reallocate_halo_output=True, # Reduce memory fragmentation + act_block_h_override=32, # Use smaller activation blocks + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, # Use height sharding + ) + self.compute_config = ttnn.init_device_compute_kernel_config( + self.device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + x = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.parameters["conv_first"]["weight"], + bias_tensor=self.parameters["conv_first"]["bias"], + in_channels=3, + out_channels=180, + device=self.device, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + batch_size=x.shape[0], + input_height=x.shape[1], + input_width=x.shape[2], + conv_config=self.conv_config, + compute_config=self.compute_config, + return_output_dim=False, + return_weights_and_bias=False, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + # slice_config=slice_config, + ) + x = ttnn.reshape(x, [batch_size, 64, 64, 180]) + x = ttnn.permute(x, (0, 3, 1, 2)) + + # Deep feature extraction - store as fea + fea = self.forward_features(x) + + # Residual connection after body (using fea, not x like in HAT) + self.conv_afterbody_config = ttnn.Conv2dConfig( + weights_dtype=ttnn.bfloat16, + activation="", + output_layout=ttnn.TILE_LAYOUT, + deallocate_activation=False, # Free input memory after use + reallocate_halo_output=True, # Reduce memory fragmentation + act_block_h_override=32, # Use smaller activation blocks + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, # Use height sharding + ) + fea = ttnn.permute(fea, (0, 2, 3, 1)) + x_after_body = ttnn.conv2d( + input_tensor=fea, + weight_tensor=self.parameters["conv_after_body"]["weight"], + bias_tensor=self.parameters["conv_after_body"]["bias"], + in_channels=180, + out_channels=180, + device=self.device, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + batch_size=fea.shape[0], + input_height=fea.shape[1], + input_width=fea.shape[2], + conv_config=self.conv_afterbody_config, + compute_config=self.compute_config, + return_output_dim=False, + return_weights_and_bias=False, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + x_after_body = ttnn.reshape(x_after_body, [batch_size, 64, 64, 180]) + x = ttnn.permute(x, (0, 2, 3, 1)) + x = ttnn.add(x, x_after_body, memory_config=self.memory_config, dtype=self.dtype) + + # Pre-upsample convolution + x = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.parameters.conv_before_upsample.weight, + bias_tensor=self.parameters.conv_before_upsample.bias, + in_channels=180, + out_channels=64, + device=self.device, + kernel_size=(3, 3), + stride=(1, 1), + 
padding=(1, 1), + batch_size=x.shape[0], + input_height=x.shape[1], + input_width=x.shape[2], + conv_config=self.conv_config, + compute_config=self.compute_config, + memory_config=self.memory_config, + dtype=ttnn.bfloat16, + return_weights_and_bias=False, + ) + + # LeakyReLU activation + x = ttnn.leaky_relu(x, negative_slope=0.01, memory_config=ttnn.DRAM_MEMORY_CONFIG) + x = ttnn.reshape(x, [batch_size, 64, 64, 64]) + + # Upsampling + x = self.upsample(x, self.parameters["upsample"]) + + # Final convolution + x = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.parameters.conv_last.weight, + bias_tensor=self.parameters.conv_last.bias, + in_channels=64, + out_channels=3, + device=self.device, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + batch_size=x.shape[0], + input_height=x.shape[1], + input_width=x.shape[2], + conv_config=self.conv_config, + compute_config=self.compute_config, + memory_config=self.memory_config, + dtype=ttnn.bfloat16, + return_weights_and_bias=False, + ) + + x = ttnn.reshape(x, [batch_size, 256, 256, 3]) + # Denormalize output + x = ttnn.divide(x, self.img_range, memory_config=self.memory_config, dtype=self.dtype) + self.mean = ttnn.permute(self.mean, (0, 2, 3, 1)) + x = ttnn.add(x, self.mean, memory_config=self.memory_config, dtype=self.dtype) + self.mean = ttnn.permute(self.mean, (0, 3, 1, 2)) + + return x, fea diff --git a/models/experimental/SSR/tt/tile_refinement/upsample.py b/models/experimental/SSR/tt/tile_refinement/upsample.py new file mode 100644 index 000000000000..a30ec792f4d1 --- /dev/null +++ b/models/experimental/SSR/tt/tile_refinement/upsample.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import math +from models.common.lightweightmodule import LightweightModule +import torch + + +class TTUpsample(LightweightModule): + def __init__(self, scale, num_feat, device, dtype=ttnn.bfloat16): + self.scale = scale + self.num_feat = num_feat + self.device = device + self.memory_config = ttnn.DRAM_MEMORY_CONFIG + self.dtype = dtype + + # Pre-calculate operation parameters + if (scale & (scale - 1)) == 0: # scale = 2^n + self.num_ops = int(math.log(scale, 2)) + self.out_channels = 4 * num_feat + self.scale_factor = 2 + elif scale == 3: + self.num_ops = 1 + self.out_channels = 9 * num_feat + self.scale_factor = 3 + else: + raise ValueError(f"Unsupported scale: {scale}") + # Initialize compute config for the device + self.compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + # Initialize conv config with no activation and default output layout + self.conv_config = ttnn.Conv2dConfig( + weights_dtype=ttnn.bfloat16, + activation="", + output_layout=ttnn.TILE_LAYOUT, + deallocate_activation=True, # Free input memory after use + reallocate_halo_output=False, # Reduce memory fragmentation + act_block_h_override=32, # Use smaller activation blocks + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, # Use height sharding + ) + + def pixel_shuffle_torch(self, x, upscale_factor): + """Implement PixelShuffle operation using PyTorch for better performance""" + # PyTorch pixel_shuffle expects NCHW format, but our tensor is NHWC + # Convert from NHWC to NCHW + torch_tensor = x.permute(0, 3, 1, 2) + + torch_output = torch.nn.functional.pixel_shuffle(torch_tensor, upscale_factor) + + # Convert back from NCHW to NHWC + torch_output = torch_output.permute(0, 2, 3, 1) + + 
ttnn_output = ttnn.from_torch( + torch_output, + device=self.device, + dtype=self.dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=self.memory_config, + ) + + return ttnn_output + + def forward(self, x, parameters): + current = x + current_channels = self.num_feat + slice_config = ttnn.Conv2dSliceConfig(slice_type=ttnn.Conv2dSliceHeight, num_slices=4) + for i in range(self.num_ops): + # Calculate output channels for this specific convolution + out_channels = current_channels * (self.scale_factor * self.scale_factor) + batch_size = current.shape[0] + height = current.shape[1] + width = current.shape[2] + current = ttnn.conv2d( + input_tensor=current, + weight_tensor=parameters[f"conv_{i}"]["weight"], + bias_tensor=parameters[f"conv_{i}"]["bias"] if parameters[f"conv_{i}"]["bias"] else None, + in_channels=current_channels, + out_channels=out_channels, + device=self.device, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + batch_size=batch_size, + input_height=height, + input_width=width, + conv_config=self.conv_config, + compute_config=self.compute_config, + dtype=self.dtype, + return_output_dim=False, + return_weights_and_bias=False, + slice_config=slice_config, + ) + + current = ttnn.to_torch(current) + # reshape B,1,H*W, C to B, H, W, C + current = current.reshape( + batch_size, + current.shape[2] // (height * batch_size), + current.shape[2] // (height * batch_size), + out_channels, + ) + current = self.pixel_shuffle_torch(current, self.scale_factor) + # After pixel shuffle, channels return to original count + current_channels = self.num_feat + + return current diff --git a/models/experimental/SSR/tt/tile_selection/__init__.py b/models/experimental/SSR/tt/tile_selection/__init__.py new file mode 100644 index 000000000000..8ec5dbff4285 --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .tile_selection import TTTileSelection +from .patch_embed import TTPatchEmbed +from .mask_token_inference import TTMaskTokenInference +from .basic_layer import TTBasicLayer, TTPatchMerging, TTSwinTransformerBlock, TTWindowAttention + +__all__ = [ + "TTTileSelection", + "TTPatchEmbed", + "TTMaskTokenInference", + "TTBasicLayer", + "TTPatchMerging", + "TTSwinTransformerBlock", + "TTWindowAttention", +] diff --git a/models/experimental/SSR/tt/tile_selection/basic_layer/__init__.py b/models/experimental/SSR/tt/tile_selection/basic_layer/__init__.py new file mode 100644 index 000000000000..9498e990cb54 --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/basic_layer/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .basic_block import TTBasicLayer +from .patch_merging import TTPatchMerging +from .swin_transformer_block import TTSwinTransformerBlock, TTWindowAttention + +__all__ = ["TTBasicLayer", "TTPatchMerging", "TTSwinTransformerBlock", "TTWindowAttention"] diff --git a/models/experimental/SSR/tt/tile_selection/basic_layer/basic_block.py b/models/experimental/SSR/tt/tile_selection/basic_layer/basic_block.py new file mode 100644 index 000000000000..82195c7078ab --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/basic_layer/basic_block.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + +from .swin_transformer_block import TTSwinTransformerBlock +from .patch_merging import TTPatchMerging + + +class TTBasicLayer(LightweightModule): + def __init__( + self, + device, + parameters, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + downsample=None, + memory_config=None, + dtype=ttnn.bfloat16, + ): + super().__init__() + self.device = device + self.memory_config = ttnn.DRAM_MEMORY_CONFIG + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.dtype = dtype + + # Build transformer blocks + self.blocks = [] + for i in range(depth): + shift_size = 0 if (i % 2 == 0) else window_size // 2 + + block = TTSwinTransformerBlock( + device=device, + parameters=parameters["blocks"][i], + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + dtype=self.dtype, + ) + self.blocks.append(block) + + # Optional downsampling layer + self.has_downsample = downsample is not None + if self.has_downsample: + self.downsample = TTPatchMerging( + device=device, + parameters=parameters["downsample"], + input_resolution=input_resolution, + dim=dim, + memory_config=memory_config, + dtype=self.dtype, + ) + + def forward(self, input_tensor): + # Process through all transformer blocks + x = input_tensor + for block in self.blocks: + x = block(x) + + # Apply downsampling if present + if self.has_downsample: + x = self.downsample(x) + + return x diff --git a/models/experimental/SSR/tt/tile_selection/basic_layer/patch_merging.py b/models/experimental/SSR/tt/tile_selection/basic_layer/patch_merging.py new file mode 100644 index 000000000000..c0233b9a2ea9 --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/basic_layer/patch_merging.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.demos.deepseek_v3.utils.config_helpers import matmul_config + + +class TTPatchMerging(LightweightModule): + def __init__( + self, + device, + parameters, + input_resolution, + dim, + memory_config=None, + dtype=ttnn.bfloat8_b, + ): + super().__init__() + self.device = device + self.memory_config = memory_config or ttnn.DRAM_MEMORY_CONFIG + self.input_resolution = input_resolution + self.dim = dim + self.dtype = dtype + + # Extract weights from preprocessed parameters + self.reduction_weight = parameters["reduction"]["weight"] + self.norm_weight = parameters["norm"]["weight"] + self.norm_bias = parameters["norm"]["bias"] + + self.kernel_top_left = parameters["conv_kernels"]["top_left"] + self.kernel_bottom_left = parameters["conv_kernels"]["bottom_left"] + self.kernel_top_right = parameters["conv_kernels"]["top_right"] + self.kernel_bottom_right = parameters["conv_kernels"]["bottom_right"] + + def forward(self, input_tensor): + """ + Args: + input_tensor: TTNN tensor with shape [B, H*W, C] + Returns: + TTNN tensor with shape [B, H/2*W/2, 2*C] + """ + H, W = self.input_resolution + B, L, C = input_tensor.shape + + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
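+        # The grouped 2x2 / stride-2 convolutions below stand in for the strided-slice
+        # gather of the reference PatchMerging. A PyTorch sketch of the behaviour being
+        # reproduced (for orientation only, not executed here):
+        #   x = x.view(B, H, W, C)
+        #   x = torch.cat([x[:, 0::2, 0::2], x[:, 1::2, 0::2],
+        #                  x[:, 0::2, 1::2], x[:, 1::2, 1::2]], dim=-1)  # [B, H/2, W/2, 4*C]
+        # Each entry of parameters["conv_kernels"] is assumed to be a one-hot 2x2 selector
+        # prepared during parameter preprocessing, so concatenating the four conv outputs
+        # yields the same [B, H/2, W/2, 4*C] tensor as the slicing above.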
+ + # Reshape to spatial dimensions [B, H, W, C] + input_tensor = ttnn.reshape(input_tensor, (B, H, W, C)) + x = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.dtype) + + # Common convolution parameters + conv_params = { + "input_tensor": x, + "in_channels": C, + "out_channels": C, + "device": self.device, + "kernel_size": (2, 2), + "stride": (2, 2), + "padding": (0, 0), + "groups": C, # Grouped convolution + "batch_size": B, + "input_height": H, + "input_width": W, + "conv_config": None, + "dtype": self.dtype, + "memory_config": ttnn.DRAM_MEMORY_CONFIG, + } + + # Apply grouped convolutions for each patch, this is instead of a slice operation + x0 = ttnn.conv2d(weight_tensor=self.kernel_top_left, **conv_params) + x1 = ttnn.conv2d(weight_tensor=self.kernel_bottom_left, **conv_params) + x2 = ttnn.conv2d(weight_tensor=self.kernel_top_right, **conv_params) + x3 = ttnn.conv2d(weight_tensor=self.kernel_bottom_right, **conv_params) + + ttnn.deallocate(x) + + # Concatenate along channel dimension [B, H/2, W/2, 4*C] + merged = ttnn.concat([x0, x1, x2, x3], dim=-1, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + # Clean up intermediate tensors + ttnn.deallocate(x0) + ttnn.deallocate(x1) + ttnn.deallocate(x2) + ttnn.deallocate(x3) + + # Reshape to sequence format [B, H/2*W/2, 4*C] + merged = ttnn.reshape(merged, (B, (H // 2) * (W // 2), 4 * C), memory_config=ttnn.L1_MEMORY_CONFIG) + + # Apply layer normalization + normalized = ttnn.layer_norm( + merged, + weight=self.norm_weight, + bias=self.norm_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + ttnn.deallocate(merged) + + # Apply linear reduction [B, H/2*W/2, 2*C] + output_matmul_config = matmul_config( + normalized.shape[-2], normalized.shape[-1], self.reduction_weight.shape[-2], (8, 8) + ) + output = ttnn.linear( + normalized, + self.reduction_weight, + memory_config=ttnn.L1_MEMORY_CONFIG, + program_config=output_matmul_config, + dtype=self.dtype, + ) + ttnn.deallocate(normalized) + + return output diff --git a/models/experimental/SSR/tt/tile_selection/basic_layer/swin_transformer_block/__init__.py b/models/experimental/SSR/tt/tile_selection/basic_layer/swin_transformer_block/__init__.py new file mode 100644 index 000000000000..632539c42ab3 --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/basic_layer/swin_transformer_block/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +from .swin_transformer_block import TTSwinTransformerBlock +from .window_attn import TTWindowAttention + +__all__ = ["TTSwinTransformerBlock", "TTWindowAttention"] diff --git a/models/experimental/SSR/tt/tile_selection/basic_layer/swin_transformer_block/swin_transformer_block.py b/models/experimental/SSR/tt/tile_selection/basic_layer/swin_transformer_block/swin_transformer_block.py new file mode 100644 index 000000000000..b8eaca64495e --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/basic_layer/swin_transformer_block/swin_transformer_block.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +from models.common.lightweightmodule import LightweightModule + +from .window_attn import TTWindowAttention +from models.experimental.SSR.tt.common import TTMlp + + +class TTSwinTransformerBlock(LightweightModule): + def __init__( + self, + parameters, + device, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + dtype=ttnn.bfloat16, + ): + super().__init__() + self.parameters = parameters + self.device = device + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.memory_config = ttnn.L1_MEMORY_CONFIG + self.dtype = dtype + + if min(self.input_resolution) <= self.window_size: + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + # Initialize attention module + self.attn = TTWindowAttention( + parameters["attn"], + device, + dim, + window_size, + num_heads, + dtype=self.dtype, + ) + + # Initialize MLP + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = TTMlp( + device=device, + in_features=dim, + hidden_features=mlp_hidden_dim, + out_features=dim, + parameters=parameters["mlp"], + dtype=self.dtype, + ) + + def _compute_attention_mask(self): + """Compute attention mask for shifted window attention""" + H, W = self.input_resolution + + # Create mask on CPU first + img_mask = torch.zeros((1, H, W, 1)) + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # Partition into windows + mask_windows, _ = self._window_partition_padding(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + # Convert to TTNN tensor + return ttnn.from_torch( + attn_mask, + device=self.device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=ttnn.bfloat4_b, + ) + + def _window_partition_padding(self, x, window_size): + """Partition into non-overlapping windows with padding if needed""" + if isinstance(x, torch.Tensor): + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = torch.nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + else: + # TTNN tensor case + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = ttnn.pad(x, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)), 0.0) + Hp, Wp = H + pad_h, W + pad_w + + x = ttnn.reshape( + x, + (B, Hp // window_size, window_size, Wp // window_size, window_size, C), + memory_config=self.memory_config, + ) + x = ttnn.permute(x, (0, 1, 3, 2, 
4, 5), memory_config=self.memory_config) + x = ttnn.reshape(x, (-1, window_size, window_size, C), memory_config=self.memory_config) + return x, (Hp, Wp) + + def _window_unpartition(self, x, window_size, pad_hw, hw): + """Window unpartition into original sequences and removing padding""" + Hp, Wp = pad_hw + H, W = hw + B = x.shape[0] // (Hp * Wp // window_size // window_size) + + x = ttnn.reshape( + x, (B, Hp // window_size, Wp // window_size, window_size, window_size, -1), memory_config=self.memory_config + ) + x = ttnn.permute(x, (0, 1, 3, 2, 4, 5), memory_config=self.memory_config) + x = ttnn.reshape(x, (B, Hp, Wp, -1), memory_config=self.memory_config) + + if Hp > H or Wp > W: + x = ttnn.slice(x, [0, 0, 0, 0], [x.shape[0], H, W, x.shape[3]], memory_config=self.memory_config) + return x + + def forward(self, input_tensor): + H, W = self.input_resolution + B, L, C = input_tensor.shape + + # Store shortcut connection + shortcut = input_tensor + shortcut = ttnn.reallocate(shortcut, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + # Layer normalization 1 + norm1_weight = self.parameters["norm1"]["weight"] + norm1_bias = self.parameters["norm1"]["bias"] + x = ttnn.layer_norm(input_tensor, weight=norm1_weight, bias=norm1_bias, memory_config=self.memory_config) + + # Reshape to spatial format + x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=self.memory_config, dtype=self.dtype) + x = ttnn.reshape(x, (B, H, W, C)) + + # Cyclic shift + if self.shift_size > 0: + # TTNN doesn't have direct roll operation, so we implement it with slicing and concatenation + x = ttnn.roll(x, [-self.shift_size, -self.shift_size], [1, 2]) + + # Partition windows + x, pad_hw = self._window_partition_padding(x, self.window_size) + x = ttnn.reshape(x, (-1, self.window_size * self.window_size, C), memory_config=self.memory_config) + x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT, memory_config=self.memory_config, dtype=self.dtype) + + # Pre-compute attention mask if needed + if self.shift_size > 0: + self.attn_mask = self._compute_attention_mask() + else: + self.attn_mask = None + + # Window attention + x = self.attn(x, mask=self.attn_mask) + + if self.attn_mask is not None: + ttnn.deallocate(self.attn_mask) + + # Merge windows + x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=self.memory_config, dtype=self.dtype) + x = ttnn.reshape(x, (-1, self.window_size, self.window_size, C), memory_config=self.memory_config) + x = self._window_unpartition(x, self.window_size, pad_hw, (H, W)) + + # Reverse cyclic shift + if self.shift_size > 0: + x = ttnn.roll(x, [self.shift_size, self.shift_size], [1, 2]) + + # Reshape back to sequence format + x = ttnn.reshape(x, (B, H * W, C)) + x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT, memory_config=self.memory_config, dtype=self.dtype) + + # First residual connection (no drop_path implementation in TTNN) + x = ttnn.add(shortcut, x, memory_config=self.memory_config, dtype=self.dtype) + + residual = ttnn.reallocate(x, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + # Layer normalization 2 + norm2_weight = self.parameters["norm2"]["weight"] + norm2_bias = self.parameters["norm2"]["bias"] + x = ttnn.layer_norm(x, weight=norm2_weight, bias=norm2_bias, memory_config=self.memory_config) + + # MLP + x = self.mlp(x) + + # Second residual connection + x = ttnn.add(residual, x, memory_config=self.memory_config, dtype=self.dtype) + + return x diff --git a/models/experimental/SSR/tt/tile_selection/basic_layer/swin_transformer_block/window_attn.py 
b/models/experimental/SSR/tt/tile_selection/basic_layer/swin_transformer_block/window_attn.py new file mode 100644 index 000000000000..52d0090ae4ed --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/basic_layer/swin_transformer_block/window_attn.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch.nn as nn + + +class TTWindowAttention(nn.Module): + def __init__( + self, + parameters, + device, + dim, + window_size, + num_heads, + dtype=ttnn.bfloat16, + ): + super().__init__() + self.parameters = parameters + self.device = device + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.dtype = dtype + + def forward(self, input_tensor, mask=None): + """ + Args: + input_tensor: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + relative_position_bias = self.parameters["relative_position_bias"] + + B_, N, C = input_tensor.shape + + # QKV projection + qkv_weight = self.parameters["qkv"]["weight"] + qkv_bias = self.parameters["qkv"]["bias"] + + qkv = ttnn.linear( + input_tensor, + qkv_weight, + bias=qkv_bias, + compute_kernel_config=ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + ), + memory_config=ttnn.L1_MEMORY_CONFIG if B_ * N * C < 1_100_000 else ttnn.DRAM_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(x=8, y=8), + dtype=self.dtype, + ) + ttnn.deallocate(input_tensor) + + # Split QKV using built-in function + ( + q, + k, + v, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + qkv, memory_config=ttnn.DRAM_MEMORY_CONFIG, num_heads=self.num_heads, transpose_key=True + ) + ttnn.deallocate(qkv) + + # Apply scaling + q = q * self.scale + + # Compute attention scores + attn = ttnn.matmul( + q, + k, + compute_kernel_config=ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + ), + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=self.dtype, + ) + + # Clean up intermediate tensors + ttnn.deallocate(q) + ttnn.deallocate(k) + + # Add relative position bias + attn = ttnn.add(attn, relative_position_bias, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.dtype) + + # Apply mask if provided + if mask is not None: + nW = mask.shape[0] + attn = ttnn.reshape(attn, (B_ // nW, nW, self.num_heads, N, N), memory_config=ttnn.L1_MEMORY_CONFIG) + attn = attn + ttnn.unsqueeze(ttnn.unsqueeze(mask, 1), 0) + attn = ttnn.reshape(attn, (-1, self.num_heads, N, N), memory_config=ttnn.L1_MEMORY_CONFIG) + + ttnn.deallocate(mask) + + # Apply softmax + attn = ttnn.softmax(attn, dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG) + + # Compute final output + output_tensor = ttnn.matmul( + attn, + v, + compute_kernel_config=ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + ), + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=self.dtype, + ) + + # Clean up attention and value tensors + ttnn.deallocate(v) + ttnn.deallocate(attn) + + output_tensor = ttnn.transformer.concatenate_heads(output_tensor, memory_config=ttnn.L1_MEMORY_CONFIG) + + # Apply projection + proj_weight = self.parameters["proj"]["weight"] + proj_bias = self.parameters["proj"]["bias"] + + output_tensor = ttnn.linear( + output_tensor, + proj_weight, + bias=proj_bias, + compute_kernel_config=ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + ), + memory_config=ttnn.L1_MEMORY_CONFIG, + 
core_grid=ttnn.CoreGrid(x=8, y=8), + dtype=self.dtype, + ) + + return output_tensor diff --git a/models/experimental/SSR/tt/tile_selection/mask_token_inference.py b/models/experimental/SSR/tt/tile_selection/mask_token_inference.py new file mode 100644 index 000000000000..60c79115136a --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/mask_token_inference.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import torch +from models.common.lightweightmodule import LightweightModule +from models.demos.deepseek_v3.utils.config_helpers import matmul_config + + +class TTMaskTokenInference(LightweightModule): + def __init__( + self, + device, + parameters, + dim, + num_heads=1, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + dtype=ttnn.bfloat16, + ): + self.device = device + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = qk_scale or (self.head_dim**-0.5) + self.dtype = dtype + + # Layer norm parameters + self.norm_weight = parameters["norm"]["weight"] # ttnn tensor for layer norm weight + self.norm_bias = parameters["norm"]["bias"] # ttnn tensor for layer norm bias + + # Linear layer weights + self.proj_weight = parameters["proj"]["weight"] # ttnn tensor for output projection + + self.proj_bias = parameters["proj"]["bias"] + + self.qkv_weight = parameters["qkv"]["weight"] # Pre-fused QKV weight tensor + self.qkv_bias = parameters["qkv"]["bias"] if qkv_bias else None + + # Scale tensor + scale_tensor = torch.tensor(self.scale).view(1, 1, 1, 1) + self.tt_scale = ttnn.from_torch(scale_tensor, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + + def __call__(self, fea): + B, N, C = fea.shape + + # Layer normalization + x = ttnn.layer_norm(fea, weight=self.norm_weight, bias=self.norm_bias, memory_config=ttnn.L1_MEMORY_CONFIG) + fea_skip = fea + fea_skip = ttnn.reallocate(fea_skip, memory_config=ttnn.DRAM_MEMORY_CONFIG) + ttnn.deallocate(fea) + + # Split into classification token and feature tokens + # T_s: classification token [B, 1, C] + # F_s: feature tokens [B, N-1, C] + T_s = ttnn.slice(x, [0, 0, 0], [B, 1, C]) + F_s = ttnn.slice(x, [0, 1, 0], [B, N, C]) + ttnn.deallocate(x) + + # Query from feature tokens + + F_s_qkv = ttnn.linear( + F_s, + self.qkv_weight, + bias=self.qkv_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(x=8, y=8), + dtype=self.dtype, + ) + ttnn.deallocate(F_s) + + # Key from classification token + # For classification token (keys and values from T_s) + T_s_qkv = ttnn.linear( + T_s, + self.qkv_weight, + bias=self.qkv_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(x=8, y=8), + dtype=self.dtype, + ) + ttnn.deallocate(T_s) + + # Split F_s QKV (for queries) + (q_from_F_s, k_from_FS, v_from_FS) = ttnn.transformer.split_query_key_value_and_split_heads( + F_s_qkv, + num_heads=self.num_heads, + transpose_key=True, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + ttnn.deallocate(F_s_qkv) + ttnn.deallocate(k_from_FS) + ttnn.deallocate(v_from_FS) + + (q_from_T_s, k_from_T_s, v_from_T_s) = ttnn.transformer.split_query_key_value_and_split_heads( + T_s_qkv, + num_heads=self.num_heads, + transpose_key=True, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + ttnn.deallocate(T_s_qkv) + ttnn.deallocate(q_from_T_s) + + # Attention computation: q @ k.T + prg_config = matmul_config(q_from_F_s.shape[-2], q_from_F_s.shape[-1], k_from_T_s.shape[-1], (8, 8)) + attn = ttnn.matmul( + q_from_F_s, 
k_from_T_s, memory_config=ttnn.L1_MEMORY_CONFIG, program_config=prg_config, dtype=self.dtype + ) + ttnn.deallocate(q_from_F_s) + ttnn.deallocate(k_from_T_s) + + # Scale attention scores + attn = ttnn.multiply(attn, self.tt_scale, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.dtype) + + # Apply sigmoid instead of softmax + attn = ttnn.sigmoid(attn, memory_config=ttnn.L1_MEMORY_CONFIG) + + # Compute attention output + infer_fea = ttnn.matmul( + attn, v_from_T_s, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.dtype, core_grid=ttnn.CoreGrid(x=8, y=8) + ) + ttnn.deallocate(attn) + ttnn.deallocate(v_from_T_s) + ttnn.reallocate(infer_fea) + + # Reshape back to [B, N-1, C] + infer_fea = ttnn.permute(infer_fea, (0, 2, 1, 3)) + infer_fea = ttnn.reshape(infer_fea, (B, N - 1, C)) + + # Output projection + infer_fea = ttnn.linear( + infer_fea, + self.proj_weight, + bias=self.proj_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(x=8, y=8), + ) + + # Apply projection dropout (if needed) + + # Residual connection with original feature tokens + original_features = ttnn.slice(fea_skip, [0, 1, 0], [B, N, C]) + infer_fea = ttnn.add(infer_fea, original_features, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.dtype) + + return infer_fea diff --git a/models/experimental/SSR/tt/tile_selection/patch_embed.py b/models/experimental/SSR/tt/tile_selection/patch_embed.py new file mode 100644 index 000000000000..3f818a0f6409 --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/patch_embed.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TTPatchEmbed(LightweightModule): + """TTNN Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + device: TTNN device + dtype: TTNN data type. 
Default: ttnn.bfloat16 + """ + + def __init__( + self, + img_size: int = 224, + patch_size: int = 4, + in_chans: int = 3, + embed_dim: int = 96, + device=None, + dtype=ttnn.bfloat16, + parameters=None, + memory_config=None, + ): + # Convert to tuples (assuming square images/patches for simplicity) + self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size + self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size + + self.patches_resolution = [self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1]] + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + self.device = device + self.dtype = dtype + self.memory_config = ttnn.L1_MEMORY_CONFIG + # Store projection parameters (weight and bias) + self.proj_weight = parameters["proj"]["weight"] + self.proj_bias = parameters["proj"]["bias"] + + # Initialize compute config for the device + self.compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + # Initialize conv config with no activation and default output layout + self.conv_config = ttnn.Conv2dConfig( + weights_dtype=self.dtype, + activation="", + output_layout=ttnn.TILE_LAYOUT, + deallocate_activation=True, # Free input memory after use + reallocate_halo_output=True, # Reduce memory fragmentation + act_block_h_override=64, # Use smaller activation blocks + ) + + def forward(self, x): + batch_size, img_h, img_w, _ = x.shape # NHWC format + + # Use DRAM slicing for large inputs + slice_config = ttnn.Conv2dSliceConfig(slice_type=ttnn.Conv2dSliceHeight, num_slices=6) + # Validate input dimensions + assert ( + img_h == self.img_size[0] and img_w == self.img_size[1] + ), f"Input image size ({img_h}*{img_w}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})" + + output, (out_height, out_width) = ttnn.conv2d( + input_tensor=x, + weight_tensor=self.proj_weight, + bias_tensor=self.proj_bias, + in_channels=self.in_chans, + out_channels=self.embed_dim, + device=self.device, + kernel_size=self.patch_size, + stride=self.patch_size, + padding=(0, 0), # Simplest case: no padding + batch_size=batch_size, + input_height=img_h, + input_width=img_w, + conv_config=self.conv_config, + compute_config=self.compute_config, + return_output_dim=True, # Only return the output tensor for simplest call + return_weights_and_bias=False, # Weights and bias are already prepared + dtype=self.dtype, # Specify output dtype + slice_config=slice_config, + ) + flattened_output = ttnn.reshape(output, (batch_size, out_height * out_width, self.embed_dim)) + + return flattened_output diff --git a/models/experimental/SSR/tt/tile_selection/tile_selection.py b/models/experimental/SSR/tt/tile_selection/tile_selection.py new file mode 100644 index 000000000000..7e69eddf78ae --- /dev/null +++ b/models/experimental/SSR/tt/tile_selection/tile_selection.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import math +from models.common.lightweightmodule import LightweightModule + +from .patch_embed import TTPatchEmbed +from .basic_layer import TTBasicLayer +from .mask_token_inference import TTMaskTokenInference +from ..common import TTMlp + + +class TTTileSelection(LightweightModule): + def __init__(self, device, parameters, args, num_cls, memory_config=None, dtype=ttnn.bfloat16): + super().__init__() + self.device = device + self.token_size = args.token_size + self.num_layers = int(math.log2((args.imgsz // args.patchsz) // args.token_size)) + self.num_cls = num_cls + self.memory_config = ttnn.DRAM_MEMORY_CONFIG + self.dtype = dtype + + # Initialize patch embedding using existing TTPatchEmbed + self.patch_embed = TTPatchEmbed( + img_size=args.imgsz, + patch_size=args.patchsz, + in_chans=3, + embed_dim=args.dim, + device=device, + parameters=parameters["patch_embed"], + memory_config=memory_config, + dtype=dtype, + ) + + # Initialize encoder layers using existing TTBasicLayer + self.layers = [] + patches_resolution = (args.imgsz // args.patchsz, args.imgsz // args.patchsz) + + for i_layer in range(self.num_layers): + layer = TTBasicLayer( + device=device, + parameters=parameters[f"layers.{i_layer}"], + dim=int(args.dim * 2**i_layer), + input_resolution=(patches_resolution[0] // (2**i_layer), patches_resolution[1] // (2**i_layer)), + depth=2, + num_heads=3, + window_size=7, + mlp_ratio=4.0, + downsample=True if i_layer < self.num_layers - 1 else False, + memory_config=memory_config, + dtype=dtype, + ) + self.layers.append(layer) + + # Layer norm parameters for different scales + self.norm3_weight = parameters["norm3"]["weight"] + self.norm3_bias = parameters["norm3"]["bias"] + self.norm2_weight = parameters["norm2"]["weight"] + self.norm2_bias = parameters["norm2"]["bias"] + self.norm1_weight = parameters["norm1"]["weight"] + self.norm1_bias = parameters["norm1"]["bias"] + + # Mask token embedding + self.mask_token_weight = parameters["mask_token"]["weight"] + + # Initialize MLPs using existing TTMlp + final_dim = 96 * (2**self.num_layers) + self.fea_mlp3 = TTMlp( + device=device, + in_features=final_dim, + hidden_features=final_dim, + out_features=final_dim, + parameters=parameters["fea_mlp3"], + dtype=dtype, + ) + + # Initialize mask token inference modules + self.mask_pre3 = TTMaskTokenInference( + device=device, parameters=parameters["mask_pre3"], dim=final_dim, num_heads=1 + ) + + # MLP norm parameters + self.mlp_norm3_weight = parameters["mlp_norm3"]["weight"] + self.mlp_norm3_bias = parameters["mlp_norm3"]["bias"] + self.mlp_norm2_weight = parameters["mlp_norm2"]["weight"] + self.mlp_norm2_bias = parameters["mlp_norm2"]["bias"] + self.mlp_norm1_weight = parameters["mlp_norm1"]["weight"] + self.mlp_norm1_bias = parameters["mlp_norm1"]["bias"] + + # Classification MLPs + self.mlp3 = TTMlp( + device=device, + in_features=final_dim, + hidden_features=96, + out_features=96, + parameters=parameters["mlp3"], + dtype=dtype, + ) + + # Linear classification layers + self.linear3_weight = parameters["linear3"]["weight"] + self.linear3_bias = parameters["linear3"]["bias"] + self.linear2_weight = parameters["linear2"]["weight"] + self.linear2_bias = parameters["linear2"]["bias"] + self.linear1_weight = parameters["linear1"]["weight"] + self.linear1_bias = parameters["linear1"]["bias"] + + def forward(self, x): + """ + Args: + x: Input tensor [B, C, H, W] + """ + x = ttnn.permute(x, (0, 2, 3, 1)) + B, C, H, W = x.shape + + # Patch embedding 
using existing TTPatchEmbed + x = self.patch_embed(x) + + # Encoder using existing TTBasicLayer components + x_downsample = [] + for layer in self.layers: + x_downsample.append(x) + x = layer(x) + + # Apply layer normalization to different scale features + x3 = ttnn.layer_norm(x, weight=self.norm3_weight, bias=self.norm3_bias) + + # Get mask tokens and expand for batch + mask_tokens = ttnn.unsqueeze(self.mask_token_weight, 0) + mask_tokens = ttnn.expand(mask_tokens, [B, -1, -1]) + + # Process scale 3 (finest scale) + fea_3_processed = self.fea_mlp3(x3) + mask_tokens = ttnn.to_layout(mask_tokens, ttnn.ROW_MAJOR_LAYOUT) + fea_3_processed = ttnn.to_layout(fea_3_processed, ttnn.ROW_MAJOR_LAYOUT) + fea_3 = ttnn.concat([mask_tokens, fea_3_processed], dim=1, memory_config=ttnn.DRAM_MEMORY_CONFIG) + fea_3 = ttnn.to_layout(fea_3, ttnn.TILE_LAYOUT) + + mask_tokens = ttnn.slice(fea_3, [0, 0, 0], [B, 1, fea_3.shape[-1]]) + mask_3 = self.mask_pre3(fea_3) + mask_3 = ttnn.layer_norm(mask_3, weight=self.mlp_norm3_weight, bias=self.mlp_norm3_bias) + mask_3 = self.mlp3(mask_3) + mask_3 = ttnn.linear(mask_3, self.linear3_weight, bias=self.linear3_bias) + mask_3 = self._reshape_output(mask_3, B, self.token_size, self.token_size) + + return mask_3 + + def _reshape_output(self, mask_output, B, H, W): + """Reshape output to spatial dimensions""" + # mask_output shape: [B, N, C] where C is num_cls + N, C = mask_output.shape[1], mask_output.shape[2] + + # Transpose and reshape: [B, N, C] -> [B, C, N] -> [B, C, H, W] + mask_output = ttnn.transpose(mask_output, 1, 2) # [B, C, N] + mask_output = ttnn.reshape(mask_output, [B, C, H, W]) + + return mask_output
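+
+
+# Minimal host-side sanity sketch (illustrative only, not part of the model): it mirrors
+# the [B, N, C] -> [B, C, H, W] mapping performed by _reshape_output above, assuming a
+# 4x4 token grid and a single class channel. It runs only when this file is executed
+# directly, never on import.
+if __name__ == "__main__":
+    import torch
+
+    B, H, W, C = 1, 4, 4, 1                 # batch, token grid height/width, num_cls
+    mask = torch.randn(B, H * W, C)         # [B, N, C] output of the linear head
+    spatial = mask.transpose(1, 2).reshape(B, C, H, W)
+    print("tile-selection mask shape:", tuple(spatial.shape))  # (1, 1, 4, 4)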