# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

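"""Test for the TTNN CAB (channel attention block) from SSR tile refinement.

Builds a reference PyTorch CAB, converts its weights to TTNN tensors, runs the
TTCAB implementation on device, and compares outputs with a PCC check.
"""
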
import pytest
import torch
import torch.nn as nn
import ttnn
from loguru import logger

from models.experimental.SSR.tt.tile_refinement import TTCAB
from tests.ttnn.utils_for_testing import check_with_pcc
from models.experimental.SSR.reference.SSR.model.tile_refinement import ChannelAttention
from models.experimental.SSR.tests.tile_refinement.test_channel_attention import create_channel_attention_preprocessor

class CAB(nn.Module):
    """Reference PyTorch CAB (channel attention block): conv -> GELU -> conv -> channel attention."""

    def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
        super().__init__()
        self.cab = nn.Sequential(
            nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
            ChannelAttention(num_feat, squeeze_factor),
        )

    def forward(self, x):
        return self.cab(x)
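
# For reference, the shape flow through CAB for an input (B, C, H, W), C = num_feat:
#   Conv2d:           (B, C, H, W) -> (B, C // compress_ratio, H, W)
#   GELU:             elementwise, shape unchanged
#   Conv2d:           (B, C // compress_ratio, H, W) -> (B, C, H, W)
#   ChannelAttention: per-channel rescaling, shape unchanged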


def create_cab_preprocessor(device):
    def custom_preprocessor(torch_model, name, ttnn_module_args):
        params = {}

        # Extract the sequential layers from CAB
        cab_layers = list(torch_model.cab.children())
        conv1 = cab_layers[0]  # First Conv2d layer
        conv2 = cab_layers[2]  # Second Conv2d layer (after GELU)
        channel_attention = cab_layers[3]  # ChannelAttention module

        # Preprocess the first convolution (3x3): weights stay on host in row-major
        # layout, and the bias is reshaped to (1, 1, 1, out_channels) as ttnn.conv2d expects
        params["conv1"] = {
            "weight": ttnn.from_torch(conv1.weight, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
            "bias": ttnn.from_torch(conv1.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
        }

        # Preprocess the second convolution (3x3) the same way
        params["conv2"] = {
            "weight": ttnn.from_torch(conv2.weight, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
            "bias": ttnn.from_torch(conv2.bias.reshape(1, 1, 1, -1), dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
        }

        # Preprocess channel attention using the existing preprocessor
        channel_attention_preprocessor = create_channel_attention_preprocessor(device)
        params["channel_attention"] = channel_attention_preprocessor(
            channel_attention, "channel_attention", ttnn_module_args
        )

        return params

    return custom_preprocessor

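# A minimal standalone sketch of the preprocessor contract (hypothetical call, for
# illustration only; the test below routes it through preprocess_model instead):
#
#   preprocessor = create_cab_preprocessor(device)
#   params = preprocessor(CAB(num_feat=180), "cab", ttnn_module_args=None)
#   # params now holds "conv1", "conv2", and "channel_attention" entries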

@pytest.mark.parametrize(
    "batch_size, num_feat, height, width, compress_ratio, squeeze_factor",
    [
        (1, 180, 64, 64, 3, 30),  # SSR config
    ],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_cab_block(device, batch_size, num_feat, height, width, compress_ratio, squeeze_factor):
    torch.manual_seed(0)

    # Create reference model
    ref_model = CAB(num_feat=num_feat, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
    ref_model.eval()

    # Create input tensor (NCHW)
    input_tensor = torch.randn(batch_size, num_feat, height, width)

    # Reference forward pass
    with torch.no_grad():
        ref_output = ref_model(input_tensor)

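    # preprocess_model runs the reference model once (via run_model) to trace module
    # calls, then applies the custom preprocessor to convert weights into TTNN tensors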
    parameters = ttnn.model_preprocessing.preprocess_model(
        initialize_model=lambda: ref_model,
        custom_preprocessor=create_cab_preprocessor(device),
        device=device,
        run_model=lambda model: model(input_tensor),
    )

    tt_model = TTCAB(
        device=device,
        parameters=parameters,
        num_feat=num_feat,
        compress_ratio=compress_ratio,
        squeeze_factor=squeeze_factor,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )

    # Convert the input to TTNN format: NCHW -> NHWC, tilized bfloat16 in L1
    tt_input = ttnn.from_torch(
        input_tensor.permute(0, 2, 3, 1),  # NCHW -> NHWC
        device=device,
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat16,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )

    # TTNN forward pass
    tt_output = tt_model(tt_input)

    # Convert back to PyTorch format
    tt_torch_output = ttnn.to_torch(tt_output)
    tt_torch_output = tt_torch_output.permute(0, 3, 1, 2)  # NHWC -> NCHW

    # Verify output shapes match before comparing values
    assert (
        ref_output.shape == tt_torch_output.shape
    ), f"Shape mismatch: ref {ref_output.shape} vs ttnn {tt_torch_output.shape}"

    # Compare outputs with a Pearson correlation coefficient (PCC) check; the 0.97
    # threshold leaves headroom for bfloat16 rounding
    does_pass, pcc_message = check_with_pcc(ref_output, tt_torch_output, 0.97)

    logger.info(f"Batch: {batch_size}, Features: {num_feat}, Size: {height}x{width}")
    logger.info(f"Compress ratio: {compress_ratio}, Squeeze factor: {squeeze_factor}")
    logger.info(f"Reference output shape: {ref_output.shape}")
    logger.info(f"TTNN output shape: {tt_torch_output.shape}")
    logger.info(pcc_message)

    if does_pass:
        logger.info("CAB Block Passed!")
    else:
        logger.warning("CAB Block Failed!")

    assert does_pass, f"PCC check failed: {pcc_message}"
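
# Run with, e.g. (the test file path is an assumption based on this repo's test layout):
#   pytest models/experimental/SSR/tests/tile_refinement/test_cab.py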