
Commit 9970595

full SSR, pcc: 0.9999514168874034
Patch Fea3 PCC: 0.9878828566413058 cleanup
1 parent 71d4b93 commit 9970595

File tree

12 files changed: +2, -283 lines

models/experimental/SSR/tests/test_tile_refinement.py

Lines changed: 0 additions & 1 deletion
@@ -248,7 +248,6 @@ def test_tile_refinement(
     tt_torch_features = tt_torch_features.permute(0, 3, 1, 2)
 
     # Compare outputs
-    print("Torch OUT: ", ref_output.shape, tt_torch_output.shape, input_shape)
     output_pass, output_pcc_message = comp_pcc(ref_output, tt_torch_output, 0.90)
     features_pass, features_pcc_message = comp_pcc(ref_features, tt_torch_features, 0.90)
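Note: comp_pcc scores agreement between the PyTorch reference and the TT-NN output as a Pearson correlation coefficient (PCC) and checks it against the threshold passed in (0.90 here, 0.99 in test_upsample below); the PCC values quoted in the commit message are scores of this kind. A minimal sketch of that style of check in plain PyTorch, leaving out the special-case handling a production utility typically adds:

import torch

def pcc(golden: torch.Tensor, computed: torch.Tensor) -> float:
    # Pearson correlation coefficient over the flattened tensors.
    g = golden.flatten().to(torch.float32)
    c = computed.flatten().to(torch.float32)
    return torch.corrcoef(torch.stack([g, c]))[0, 1].item()

# Usage mirroring the test: compare a reference against a slightly perturbed stand-in output.
ref = torch.randn(1, 3, 64, 64)
out = ref + 0.01 * torch.randn_like(ref)
score = pcc(ref, out)
print(f"PCC: {score:.6f}, pass: {score >= 0.90}")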

models/experimental/SSR/tests/test_upsample.py

Lines changed: 0 additions & 2 deletions
@@ -62,7 +62,6 @@ def test_upsample(device, scale, num_feat, batch_size, input_size):
     # Create test input
     torch_input = torch.randn(batch_size, num_feat, input_size, input_size)
     torch_output = torch_model(torch_input)
-    print("LIKEEEEEEEEEE:", torch_output.shape)
 
     # Preprocess model parameters
     parameters = preprocess_model_parameters(
@@ -82,7 +81,6 @@ def test_upsample(device, scale, num_feat, batch_size, input_size):
     tt_torch_output = tt2torch_tensor(ttnn_output)
     tt_torch_output = tt_torch_output.permute(0, 3, 1, 2)
 
-    print(torch_output.shape, tt_torch_output.shape)
     does_pass, pcc_message = comp_pcc(torch_output, tt_torch_output, 0.99)
 
     logger.info(pcc_message)
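For orientation, the module under test takes scale and num_feat and upsamples a [batch, num_feat, H, W] tensor; the usual PyTorch reference in SwinIR/HAT-style super-resolution models is a conv followed by PixelShuffle, repeated per x2 stage. A hedged sketch of that reference block (assumed structure, not necessarily the repo's exact torch_model):

import math
import torch
from torch import nn

class Upsample(nn.Sequential):
    # Conv to scale**2 * num_feat channels, then PixelShuffle, per stage.
    def __init__(self, scale: int, num_feat: int):
        layers = []
        if (scale & (scale - 1)) == 0:  # scale is a power of two: stack x2 stages
            for _ in range(int(math.log2(scale))):
                layers += [nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)]
        elif scale == 3:
            layers += [nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1), nn.PixelShuffle(3)]
        else:
            raise ValueError(f"unsupported scale {scale}")
        super().__init__(*layers)

# Shapes match the test input: [batch, num_feat, H, W] -> [batch, num_feat, H*scale, W*scale]
x = torch.randn(1, 64, 16, 16)
print(Upsample(2, 64)(x).shape)  # torch.Size([1, 64, 32, 32])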

models/experimental/SSR/tt/CAB.py

Lines changed: 0 additions & 14 deletions
@@ -10,16 +10,8 @@ class TTCAB(LightweightModule):
     def __init__(self, device, parameters, num_feat, compress_ratio=3, squeeze_factor=30, memory_config=None):
         super().__init__()
 
-        # Debug print for __init__ parameters
-        # print(f"[TTCAB.__init__] device={device}, num_feat={num_feat}, compress_ratio={compress_ratio}, squeeze_factor={squeeze_factor}, memory_config={memory_config}")
-        # print(f"[TTCAB.__init__] parameters keys: {list(parameters.keys())}")
-        # print(f"[TTCAB.__init__] conv1 keys: {list(parameters['conv1'].keys()) if 'conv1' in parameters else 'N/A'}")
-        # print(f"[TTCAB.__init__] conv2 keys: {list(parameters['conv2'].keys()) if 'conv2' in parameters else 'N/A'}")
-        # print(f"[TTCAB.__init__] channel_attention keys: {list(parameters['channel_attention'].keys()) if 'channel_attention' in parameters else 'N/A'}")
-
         self.device = device
         self.memory_config = ttnn.L1_MEMORY_CONFIG
-        # self.memory_config = memory_config or ttnn.DRAM_MEMORY_CONFIG
         self.num_feat = num_feat
         self.compress_ratio = compress_ratio
         self.squeeze_factor = squeeze_factor
@@ -40,9 +32,6 @@ def __init__(self, device, parameters, num_feat, compress_ratio=3, squeeze_facto
         )
 
     def forward(self, x):
-        # Debug print for forward input
-        # print(f"[TTCAB.forward] x.shape={x.shape}")
-
         # Store original input shape for convolutions
         batch_size, height, width, channels = x.shape
         conv_config = ttnn.Conv2dConfig(
@@ -79,9 +68,6 @@ def forward(self, x):
         # Reshape from flattened conv output back to spatial format
         x = ttnn.reshape(x, [batch_size, height, width, self.num_feat // self.compress_ratio])
 
-        # # GELU activation
-        # x = ttnn.gelu(x)
-
         conv_config = ttnn.Conv2dConfig(
             weights_dtype=ttnn.bfloat16,
             shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
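For orientation, TTCAB operates on NHWC tensors and mirrors a HAT-style conv, activation, conv block followed by channel attention; note the GELU between the two convolutions is commented out in this TT version. A hedged NCHW reference sketch of that structure (assumed standard HAT layout, not the repo's exact torch code):

import torch
from torch import nn

class CAB(nn.Module):
    # Conv (num_feat -> num_feat // compress_ratio), GELU, conv back to num_feat,
    # then a squeeze-and-excite channel-attention gate.
    def __init__(self, num_feat: int, compress_ratio: int = 3, squeeze_factor: int = 30):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1)
        self.act = nn.GELU()
        self.conv2 = nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):  # x: [batch, num_feat, H, W]
        y = self.conv2(self.act(self.conv1(x)))
        return y * self.channel_attention(y)

x = torch.randn(1, 180, 64, 64)
print(CAB(180)(x).shape)  # torch.Size([1, 180, 64, 64])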

models/experimental/SSR/tt/HAB.py

Lines changed: 0 additions & 1 deletion
@@ -63,7 +63,6 @@ def __init__(
         )
 
     def forward(self, x, x_size, rpi_sa, attn_mask):
-        print("HAB")
         h, w = x_size
         b, seq_len, c = x.shape
         if x.memory_config().buffer_type != ttnn.BufferType.L1:

models/experimental/SSR/tt/RHAG.py

Lines changed: 0 additions & 15 deletions
@@ -159,11 +159,9 @@ def forward(self, x, x_size, params):
 
         # Pass through residual group (AttenBlocks)
         x = self.residual_group(x, x_size, params)
-        # return x
 
         # Patch unembed: convert from sequence to spatial format
         x = self.patch_unembed(x, x_size)
-        # return x
 
         # Apply convolutional layer
         if self.resi_connection == "1conv":
@@ -194,23 +192,10 @@ def forward(self, x, x_size, params):
         elif self.resi_connection == "identity":
             x = ttnn.permute(x, (0, 2, 3, 1))  # (batch_size, embed_dim, num_patches)
 
-            # x = ttnn.reshape(x, (x.shape[0], self.input_resolution[0], self.input_resolution[1], self.dim))
-            # Identity - no operation needed
-            # pass
-            # return x
-
         # Patch embed: convert back to sequence format
         x = self.patch_embed(x)
-        # import pdb
-
-        # pdb.set_trace()
 
         x = ttnn.reshape(x, (x.shape[0], self.input_resolution[0] * self.input_resolution[1], self.dim))
-        # return x
-
-        # import pdb
-
-        # pdb.set_trace()
 
         # Add residual connection
         x = ttnn.add(x, shortcut)
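The surviving context lines spell out the RHAG forward pattern: run the residual group of attention blocks, unembed patches to spatial form, apply the resi_connection conv (or identity), re-embed, and add the shortcut. A schematic sketch of that control flow with the sub-modules passed in as opaque callables (hypothetical helper, illustration only):

import torch

def rhag_forward(x, x_size, params, *, residual_group, patch_unembed, conv, patch_embed):
    # x: [batch, num_patches, dim]; x_size: (h, w)
    shortcut = x
    x = residual_group(x, x_size, params)  # attention blocks, sequence format
    x = patch_unembed(x, x_size)           # back to spatial format
    x = conv(x)                            # "1conv" / "3conv" / identity residual connection
    x = patch_embed(x)                     # back to sequence format
    return x + shortcut                    # residual add

# Toy run with identity stand-ins just to exercise the flow.
identity = lambda t, *args, **kwargs: t
out = rhag_forward(
    torch.randn(1, 4096, 180), (64, 64), {},
    residual_group=identity, patch_unembed=identity, conv=identity, patch_embed=identity,
)
print(out.shape)  # torch.Size([1, 4096, 180])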

models/experimental/SSR/tt/channel_attention.py

Lines changed: 1 addition & 29 deletions
@@ -41,48 +41,20 @@ def forward(self, x):
             ends=[original_shape[0], 1, 1, 180],  # End indices - slice to 180 in last dim
             steps=[1, 1, 1, 1],  # Step size for each dimension
         )
-        # TODO: find ways to generalise for all inputs, setting program config messes up the multi batch runs..
-        # Matrix multiplication 1: [1, 180] @ [180, 6] = [1, 6]
-        # program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
-        #     compute_with_storage_grid_size=(10, 1),  # or (5, 2) to accommodate 10 batches
-        #     in0_block_w=6,  # Keep same as your calculation
-        #     out_subblock_h=1,
-        #     out_subblock_w=1,
-        #     per_core_M=1,  # Each core handles 1 batch worth of M dimension
-        #     per_core_N=1,
-        #     fuse_batch=True,
-        #     fused_activation=None,
-        #     mcast_in0=False,
-        # )
+
         x = ttnn.linear(
             x,
             self.conv1_weight,
             bias=self.conv1_bias,
             memory_config=self.memory_config,
-            # program_config=program_config,
             activation="relu",
-            # compute_kernel_config=compute_kernel_config,  # set to HiFi2 to improve accuracy
         )
 
-        # Matrix multiplication 2: [1, 6] @ [6, 180] = [1, 180]
-        # program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
-        #     compute_with_storage_grid_size=(6, 6),
-        #     in0_block_w=1,  # 32(6 padded) / 32(1 tile size) = 1
-        #     out_subblock_h=1,
-        #     out_subblock_w=1,
-        #     per_core_M=1,
-        #     per_core_N=1,
-        #     fuse_batch=True,
-        #     fused_activation=None,
-        #     mcast_in0=False,
-        # )
         x = ttnn.linear(
             x,
             self.conv2_weight,
             bias=self.conv2_bias,
             memory_config=self.memory_config,
-            # program_config=program_config,
-            # compute_kernel_config=compute_kernel_config,  # set to HiFi2 to improve accuracy
         )
 
         # Sigmoid activation
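The deleted comments document the shapes involved: the pooled [1, 180] channel descriptor is reduced to [1, 6] (num_feat / squeeze_factor = 180 / 30) with a fused ReLU, expanded back to [1, 180], and passed through a sigmoid to produce the per-channel gate. The same squeeze-and-excite arithmetic in plain torch, with random weights standing in for conv1/conv2 (shapes taken from the comments, values illustrative):

import torch

num_feat, squeeze_factor = 180, 30

# Per-channel descriptor after global average pooling, kept 4D as in the TT code: [batch, 1, 1, num_feat]
pooled = torch.randn(1, 1, 1, num_feat)

w1 = torch.randn(num_feat, num_feat // squeeze_factor)  # [180, 6]
b1 = torch.zeros(num_feat // squeeze_factor)
w2 = torch.randn(num_feat // squeeze_factor, num_feat)  # [6, 180]
b2 = torch.zeros(num_feat)

hidden = torch.relu(pooled @ w1 + b1)    # [1, 180] @ [180, 6] = [1, 6], fused ReLU
gate = torch.sigmoid(hidden @ w2 + b2)   # [1, 6] @ [6, 180] = [1, 180], then sigmoid
print(hidden.shape, gate.shape)  # torch.Size([1, 1, 1, 6]) torch.Size([1, 1, 1, 180])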

models/experimental/SSR/tt/mask_token_inference.py

Lines changed: 0 additions & 93 deletions
@@ -8,34 +8,6 @@
 
 
 class TTMaskTokenInference(LightweightModule):
-    # def __init__(
-    #     self, device, parameters, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0
-    # ):
-    #     self.device = device
-    #     self.dim = dim
-    #     self.num_heads = num_heads
-    #     self.head_dim = dim // num_heads
-    #     self.scale = qk_scale or (self.head_dim**-0.5)
-
-    #     # Layer norm parameters (would need to be loaded from state dict)
-    #     self.norm_weight = parameters["norm"]["weight"]  # ttnn tensor for layer norm weight
-    #     self.norm_bias = parameters["norm"]["bias"]  # ttnn tensor for layer norm bias
-
-    #     # Linear layer weights (would need to be preprocessed and loaded)
-    #     self.q_weight = parameters["q"]["weight"]  # ttnn tensor for query projection
-    #     self.k_weight = parameters["k"]["weight"]  # ttnn tensor for key projection
-    #     self.v_weight = parameters["v"]["weight"]  # ttnn tensor for value projection
-    #     self.proj_weight = parameters["proj"]["weight"]  # ttnn tensor for output projection
-
-    #     self.q_bias = parameters["q"]["bias"] if qkv_bias else None
-    #     self.k_bias = parameters["k"]["bias"] if qkv_bias else None
-    #     self.v_bias = parameters["v"]["bias"] if qkv_bias else None
-    #     self.proj_bias = parameters["proj"]["bias"]
-
-    #     # Scale tensor
-    #     scale_tensor = torch.tensor(self.scale).view(1, 1, 1, 1)
-    #     self.tt_scale = ttnn.from_torch(scale_tensor, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
-
     def __init__(
         self, device, parameters, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0
     ):
@@ -61,71 +33,6 @@ def __init__(
         scale_tensor = torch.tensor(self.scale).view(1, 1, 1, 1)
         self.tt_scale = ttnn.from_torch(scale_tensor, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
 
-    # def __call__(self, fea):
-    #     B, N, C = fea.shape
-
-    #     # Layer normalization
-    #     x = ttnn.layer_norm(fea, weight=self.norm_weight, bias=self.norm_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
-    #     fea_skip = fea
-    #     fea_skip = ttnn.reallocate(fea_skip, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-    #     ttnn.deallocate(fea)
-
-    #     # Split into classification token and feature tokens
-    #     # T_s: classification token [B, 1, C]
-    #     # F_s: feature tokens [B, N-1, C]
-    #     T_s = ttnn.slice(x, [0, 0, 0], [B, 1, C])
-    #     F_s = ttnn.slice(x, [0, 1, 0], [B, N, C])
-    #     ttnn.deallocate(x)
-
-    #     # Query from feature tokens
-    #     q = ttnn.linear(F_s, self.q_weight, bias=self.q_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
-    #     # q = ttnn.reshape(q, (B, N - 1, self.num_heads, self.head_dim))
-    #     # q = ttnn.permute(q, (0, 2, 1, 3))
-
-    #     # Key from classification token
-    #     k = ttnn.linear(T_s, self.k_weight, bias=self.k_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
-    #     # k = ttnn.reshape(k, (B, 1, self.num_heads, self.head_dim))
-    #     # k = ttnn.permute(k, (0, 2, 1, 3))
-
-    #     # Value from classification token
-    #     v = ttnn.linear(T_s, self.v_weight, bias=self.v_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
-    #     # v = ttnn.reshape(v, (B, 1, self.num_heads, self.head_dim))
-    #     # v = ttnn.permute(v, (0, 2, 1, 3))
-
-    #     # Attention computation: q @ k.T
-    #     k_transposed = ttnn.transpose(k, -2, -1)
-    #     attn = ttnn.matmul(q, k_transposed)
-
-    #     # Scale attention scores
-    #     attn = ttnn.multiply(attn, self.tt_scale)
-
-    #     # Apply sigmoid instead of softmax
-    #     attn = ttnn.sigmoid(attn)
-
-    #     # Apply attention dropout (if needed, would require custom implementation)
-    #     # attn = apply_dropout(attn, attn_drop)
-
-    #     # Compute attention output
-    #     infer_fea = ttnn.matmul(attn, v)
-
-    #     # Reshape back to [B, N-1, C]
-    #     infer_fea = ttnn.permute(infer_fea, (0, 2, 1, 3))
-    #     infer_fea = ttnn.to_layout(infer_fea, layout=ttnn.ROW_MAJOR_LAYOUT)
-    #     infer_fea = ttnn.reshape(infer_fea, (B, N - 1, C))
-    #     infer_fea = ttnn.to_layout(infer_fea, layout=ttnn.TILE_LAYOUT)
-
-    #     # Output projection
-    #     infer_fea = ttnn.linear(infer_fea, self.proj_weight, bias=self.proj_bias)
-
-    #     # Apply projection dropout (if needed)
-    #     # infer_fea = apply_dropout(infer_fea, proj_drop)
-
-    #     # Residual connection with original feature tokens
-    #     original_features = ttnn.slice(fea_skip, [0, 1, 0], [B, N, C])
-    #     infer_fea = ttnn.add(infer_fea, original_features)
-
-    #     return infer_fea
-
     def __call__(self, fea):
         B, N, C = fea.shape
 

models/experimental/SSR/tt/mlp.py

Lines changed: 0 additions & 7 deletions
@@ -18,13 +18,8 @@ def __init__(self, device, in_features, hidden_features=None, out_features=None,
         self.fc2_bias = parameters["fc2"]["bias"]
 
     def forward(self, x):
-        # Debug prints for forward arguments
        if x.memory_config().buffer_type != ttnn.BufferType.L1:
            x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG)
-        # First linear layer
-        # program_config = matmul_config(
-        #     x.shape[-2], x.shape[-1], self.fc1_bias.shape[-1], (8, 8), fused_activation=(ttnn.UnaryOpType.GELU, True)
-        # )
         x = ttnn.linear(
             x,
             self.fc1_weight,
@@ -34,8 +29,6 @@ def forward(self, x):
             activation="gelu",
         )
 
-        # program_config = matmul_config(x.shape[-2], x.shape[-1], self.fc2_bias.shape[-1], (8, 8))
-        # Second linear layer
         x = ttnn.linear(
             x,
             self.fc2_weight,
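The two ttnn.linear calls implement a standard transformer MLP, with the GELU fused into the first matmul via activation="gelu". The torch reference they correspond to is roughly the following (dropout omitted, hedged sketch):

import torch
from torch import nn

class Mlp(nn.Module):
    # fc1 -> GELU -> fc2
    def __init__(self, in_features, hidden_features=None, out_features=None):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

x = torch.randn(1, 4096, 180)
print(Mlp(180, 360)(x).shape)  # torch.Size([1, 4096, 180])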

models/experimental/SSR/tt/patch_embed_tile_refinement.py

Lines changed: 0 additions & 6 deletions
@@ -59,15 +59,9 @@ def forward(self, x):
         Returns:
             Output tensor of shape [batch, num_patches, embed_dim]
         """
-        # batch_size, channels, height_width = x.shape
 
-        # Flatten spatial dimensions and transpose
-        # x.flatten(2) -> [batch, channels, height*width]
-        # .transpose(1, 2) -> [batch, height*width, channels]
         if x.is_sharded():
             x = ttnn.to_memory_config(x, ttnn.DRAM_MEMORY_CONFIG)
-        # x = ttnn.reshape(x, (batch_size, channels, height_width))
-        # x = ttnn.transpose(x, 1, 2)  # [batch, height*width, channels]
 
         # Apply normalization if available
         if self.norm_weight is not None:
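The removed comments describe the reference operation: flatten the spatial dimensions and transpose, i.e. x.flatten(2).transpose(1, 2), mapping [batch, embed_dim, H, W] to [batch, H*W, embed_dim], with an optional layer norm afterwards. For example, in torch (embed_dim of 180 assumed for illustration):

import torch

x = torch.randn(1, 180, 64, 64)                            # [batch, embed_dim, H, W]
patches = x.flatten(2).transpose(1, 2)                     # [batch, H*W, embed_dim]
patches = torch.nn.functional.layer_norm(patches, (180,))  # optional norm
print(patches.shape)  # torch.Size([1, 4096, 180])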

models/experimental/SSR/tt/tile_refinement.py

Lines changed: 0 additions & 15 deletions
@@ -145,14 +145,6 @@ def forward_features(self, x):
         """Forward pass through transformer layers"""
         x_size = (self.h, self.w)
 
-        # # Calculate attention mask
-        # attn_mask = self.calculate_mask(x_size)
-        # # params = {
-        # #     "attn_mask": attn_mask,
-        # #     "rpi_sa": self.parameters.relative_position_index_SA,
-        # #     "rpi_oca": self.parameters.relative_position_index_OCA,
-        # # }
-
         # Patch embedding
         x = self.patch_embed(x)
 
@@ -171,7 +163,6 @@ def forward_features(self, x):
             bias=self.parameters.norm.bias,
             memory_config=self.memory_config,
         )
-        # return x
 
         # Patch unembedding
         x = self.patch_unembed(x, x_size)
@@ -265,7 +256,6 @@ def forward(self, x):
         self.mean = ttnn.to_layout(self.mean, ttnn.TILE_LAYOUT)
         x = ttnn.subtract(x, self.mean, memory_config=self.memory_config)
         x = ttnn.multiply(x, self.img_range, memory_config=self.memory_config)
-        # return x
 
         if self.upsampler == "pixelshuffle":
             # Shallow feature extraction
@@ -298,8 +288,6 @@ def forward(self, x):
                 batch_size=x.shape[0],
                 input_height=x.shape[1],
                 input_width=x.shape[2],
-                # dilation= [1, 1],
-                # groups = 1,
                 conv_config=self.conv_config,
                 compute_config=self.compute_config,
                 return_output_dim=False,
@@ -338,7 +326,6 @@ def forward(self, x):
                 batch_size=fea.shape[0],
                 input_height=fea.shape[1],
                 input_width=fea.shape[2],
-                # dilation= [1, 1],
                 conv_config=self.conv_afterbody_config,
                 compute_config=self.compute_config,
                 return_output_dim=False,
@@ -375,8 +362,6 @@ def forward(self, x):
             x = ttnn.leaky_relu(x, negative_slope=0.01, memory_config=ttnn.DRAM_MEMORY_CONFIG)
             x = ttnn.reshape(x, [batch_size, 64, 64, 64])  # TODO
 
-            # x = ttnn.permute(x, (0, 3, 1, 2))
-
             # Upsampling
             x = self.upsample(x, self.parameters["upsample"])
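The context lines above show the usual SwinIR/HAT-style pre-processing at the top of forward: subtract the channel mean and scale by img_range before the pixelshuffle branch, with the inverse typically applied after upsampling. A minimal torch sketch of that pair of steps, with img_range and the RGB mean values assumed for illustration:

import torch

img_range = 1.0
rgb_mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)  # assumed values

def preprocess(x):
    return (x - rgb_mean) * img_range   # mirrors the ttnn.subtract / ttnn.multiply pair

def postprocess(y):
    return y / img_range + rgb_mean     # inverse, applied after upsampling

x = torch.rand(1, 3, 64, 64)
print(torch.allclose(postprocess(preprocess(x)), x, atol=1e-6))  # True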

0 commit comments