
Commit 6958b5c

amanksign authored and krishnanand committed
Refactors patch_merging to take weights from test
1 parent 46ad202 commit 6958b5c

4 files changed: +58 -103 lines changed

models/experimental/SSR/tests/test_basic_block.py

Lines changed: 3 additions & 4 deletions
@@ -24,13 +24,12 @@ def to_2tuple(x):
     return (x, x)
 
 
-def create_basic_layer_preprocessor(device):
+def create_basic_layer_preprocessor(device, dim):
     def custom_preprocessor(torch_model, name, ttnn_module_args):
         params = {"blocks": {}}
 
         # Process each transformer block
         for i, block in enumerate(torch_model.blocks):
-            # relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
             params["blocks"][i] = preprocess_model_parameters(
                 initialize_model=lambda: block,
                 custom_preprocessor=create_swin_transformer_block_preprocessor(device),
@@ -41,7 +40,7 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
         if torch_model.downsample is not None:
            params["downsample"] = preprocess_model_parameters(
                 initialize_model=lambda: torch_model.downsample,
-                custom_preprocessor=create_patch_merging_preprocessor(device),
+                custom_preprocessor=create_patch_merging_preprocessor(device, dim),
                 device=device,
             )
 
@@ -88,7 +87,7 @@ def test_basic_layer(device, batch_size, input_resolution, dim, depth, num_heads
     # Create ttnn model
     params = preprocess_model_parameters(
         initialize_model=lambda: ref_layer,
-        custom_preprocessor=create_basic_layer_preprocessor(device),
+        custom_preprocessor=create_basic_layer_preprocessor(device, dim),
         device=device,
     )
 
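
Note: each create_*_preprocessor above is a factory that closes over device (and now dim) and returns the callback that preprocess_model_parameters invokes. A minimal sketch of the pattern, with hypothetical names and assuming the callback returns its params dict as the bodies above suggest:

    def create_example_preprocessor(device, dim):
        # The closure captures device and dim so the inner function can keep
        # the (torch_model, name, ttnn_module_args) signature that
        # preprocess_model_parameters expects.
        def custom_preprocessor(torch_model, name, ttnn_module_args):
            params = {}
            # ... convert torch_model's weights to TTNN tensors sized by dim ...
            return params

        return custom_preprocessor

Threading dim through this chain is what lets create_patch_merging_preprocessor build correctly sized kernels ahead of time.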

models/experimental/SSR/tests/test_patch_merging.py

Lines changed: 23 additions & 2 deletions
@@ -12,10 +12,31 @@
 from models.experimental.SSR.tt.patch_merging import TTPatchMerging
 
 
-def create_patch_merging_preprocessor(device):
+def create_patch_merging_preprocessor(device, dim):
     def custom_preprocessor(torch_model, name, ttnn_module_args):
         params = {}
 
+        # Create conv kernels for patch merging (previously built in the forward pass)
+        kernel_top_left = torch.zeros(dim, 1, 2, 2, dtype=torch.bfloat16)
+        kernel_top_left[:, 0, 0, 0] = 1.0
+
+        kernel_bottom_left = torch.zeros(dim, 1, 2, 2, dtype=torch.bfloat16)
+        kernel_bottom_left[:, 0, 1, 0] = 1.0
+
+        kernel_top_right = torch.zeros(dim, 1, 2, 2, dtype=torch.bfloat16)
+        kernel_top_right[:, 0, 0, 1] = 1.0
+
+        kernel_bottom_right = torch.zeros(dim, 1, 2, 2, dtype=torch.bfloat16)
+        kernel_bottom_right[:, 0, 1, 1] = 1.0
+
+        # Convert to TTNN tensors
+        params["conv_kernels"] = {
+            "top_left": ttnn.from_torch(kernel_top_left, device=device),
+            "bottom_left": ttnn.from_torch(kernel_bottom_left, device=device),
+            "top_right": ttnn.from_torch(kernel_top_right, device=device),
+            "bottom_right": ttnn.from_torch(kernel_bottom_right, device=device),
+        }
+
         # Linear reduction layer
         params["reduction"] = {
             "weight": ttnn.from_torch(
@@ -76,7 +97,7 @@ def test_patch_merging(device, batch_size, input_resolution, dim):
     # Create ttnn model
     params = preprocess_model_parameters(
         initialize_model=lambda: ref_layer,
-        custom_preprocessor=create_patch_merging_preprocessor(device),
+        custom_preprocessor=create_patch_merging_preprocessor(device, dim),
         device=device,
     )
 
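
Each kernel built above is one-hot at a single position of the 2x2 window, so a stride-2 grouped (depthwise) convolution with it simply gathers that corner pixel from every patch. A standalone PyTorch sketch (not part of this commit) showing the equivalence with Swin-style strided slicing:

    import torch
    import torch.nn.functional as F

    B, H, W, C = 1, 8, 8, 4
    x = torch.randn(B, H, W, C)

    # One-hot depthwise kernel that picks the top-left pixel of each 2x2 patch
    kernel_top_left = torch.zeros(C, 1, 2, 2)
    kernel_top_left[:, 0, 0, 0] = 1.0

    x_nchw = x.permute(0, 3, 1, 2)  # torch conv2d expects NCHW
    x0_conv = F.conv2d(x_nchw, kernel_top_left, stride=2, groups=C)
    x0_slice = x[:, 0::2, 0::2, :].permute(0, 3, 1, 2)  # Swin-style slicing

    assert torch.allclose(x0_conv, x0_slice)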

models/experimental/SSR/tests/test_tile_selection.py

Lines changed: 4 additions & 6 deletions
@@ -13,7 +13,7 @@
 from models.utility_functions import tt2torch_tensor, comp_pcc
 
 
-def create_tile_selection_preprocessor(device):
+def create_tile_selection_preprocessor(device, dim=96):
     def custom_preprocessor(torch_model, name, ttnn_module_args):
         parameters = {}
 
@@ -36,9 +36,10 @@ def custom_preprocessor(torch_model, name, ttnn_module_args):
         # Handle encoder layers - delegate to existing TTBasicLayer preprocessor
         if hasattr(torch_model, "layers"):
             for i, layer in enumerate(torch_model.layers):
+                layer_dim = int(dim * 2**i)
                 layer_params = preprocess_model_parameters(
                     initialize_model=lambda l=layer: l,
-                    custom_preprocessor=create_basic_layer_preprocessor(device),
+                    custom_preprocessor=create_basic_layer_preprocessor(device, layer_dim),
                     device=device,
                 )
                 parameters[f"layers.{i}"] = layer_params
@@ -127,16 +128,13 @@ def __init__(self, imgsz, patchsz, token_size, dim):
 
     parameters = preprocess_model_parameters(
         initialize_model=lambda: ref_layer,
-        custom_preprocessor=create_tile_selection_preprocessor(device),
+        custom_preprocessor=create_tile_selection_preprocessor(device, dim),
         device=device,
     )
 
     # Create TTNN implementation
     tt_layer = TTTileSelection(device=device, parameters=parameters, args=args, num_cls=num_cls)
 
-    # NCHW -> NHWC
-    input_tensor = input_tensor.permute(0, 2, 3, 1)
-
     # Convert input to TTNN
     tt_input = ttnn.from_torch(input_tensor, device=device, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16)
 
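
The added layer_dim = int(dim * 2**i) encodes the usual Swin convention: each downsampling stage doubles the channel width, so layer i operates on dim * 2**i channels. A hedged illustration, using the default base dim=96 from the signature above and assuming a 4-stage encoder:

    dim = 96
    stage_dims = [int(dim * 2**i) for i in range(4)]
    print(stage_dims)  # [96, 192, 384, 768]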

models/experimental/SSR/tt/patch_merging.py

Lines changed: 28 additions & 91 deletions
@@ -3,7 +3,6 @@
 
 import ttnn
 from models.common.lightweightmodule import LightweightModule
-import torch
 from models.demos.deepseek_v3.utils.config_helpers import matmul_config
 
 
@@ -27,6 +26,11 @@ def __init__(
         self.norm_weight = parameters["norm"]["weight"]
         self.norm_bias = parameters["norm"]["bias"]
 
+        self.kernel_top_left = parameters["conv_kernels"]["top_left"]
+        self.kernel_bottom_left = parameters["conv_kernels"]["bottom_left"]
+        self.kernel_top_right = parameters["conv_kernels"]["top_right"]
+        self.kernel_bottom_right = parameters["conv_kernels"]["bottom_right"]
+
     def forward(self, input_tensor):
         """
         Args:
@@ -44,96 +48,29 @@ def forward(self, input_tensor):
         input_tensor = ttnn.reshape(input_tensor, (B, H, W, C))
         x = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG)
 
-        kernel_top_left = torch.zeros(C, 1, 2, 2, dtype=torch.bfloat16)
-        kernel_top_left[:, 0, 0, 0] = 1.0
-
-        kernel_bottom_left = torch.zeros(C, 1, 2, 2, dtype=torch.bfloat16)
-        kernel_bottom_left[:, 0, 1, 0] = 1.0
-
-        kernel_top_right = torch.zeros(C, 1, 2, 2, dtype=torch.bfloat16)
-        kernel_top_right[:, 0, 0, 1] = 1.0
-
-        kernel_bottom_right = torch.zeros(C, 1, 2, 2, dtype=torch.bfloat16)
-        kernel_bottom_right[:, 0, 1, 1] = 1.0
-
-        # Convert to TTNN tensors
-        tt_kernel_top_left = ttnn.from_torch(kernel_top_left, device=self.device)
-        tt_kernel_bottom_left = ttnn.from_torch(kernel_bottom_left, device=self.device)
-        tt_kernel_top_right = ttnn.from_torch(kernel_top_right, device=self.device)
-        tt_kernel_bottom_right = ttnn.from_torch(kernel_bottom_right, device=self.device)
-
-        # Apply grouped convolutions for each patch
-        x0 = ttnn.conv2d(
-            input_tensor=x,
-            weight_tensor=tt_kernel_top_left,
-            in_channels=C,
-            out_channels=C,
-            device=self.device,
-            kernel_size=(2, 2),
-            stride=(2, 2),
-            padding=(0, 0),
-            groups=C,  # Grouped convolution
-            batch_size=B,
-            input_height=H,
-            input_width=W,
-            conv_config=None,
-            dtype=ttnn.bfloat16,
-            memory_config=ttnn.DRAM_MEMORY_CONFIG,
-        )
-
-        x1 = ttnn.conv2d(
-            input_tensor=x,
-            weight_tensor=tt_kernel_bottom_left,
-            in_channels=C,
-            out_channels=C,
-            device=self.device,
-            kernel_size=(2, 2),
-            stride=(2, 2),
-            padding=(0, 0),
-            groups=C,
-            batch_size=B,
-            input_height=H,
-            input_width=W,
-            conv_config=None,
-            dtype=ttnn.bfloat16,
-            memory_config=ttnn.DRAM_MEMORY_CONFIG,
-        )
-
-        x2 = ttnn.conv2d(
-            input_tensor=x,
-            weight_tensor=tt_kernel_top_right,
-            in_channels=C,
-            out_channels=C,
-            device=self.device,
-            kernel_size=(2, 2),
-            stride=(2, 2),
-            padding=(0, 0),
-            groups=C,
-            batch_size=B,
-            input_height=H,
-            input_width=W,
-            conv_config=None,
-            dtype=ttnn.bfloat16,
-            memory_config=ttnn.DRAM_MEMORY_CONFIG,
-        )
-
-        x3 = ttnn.conv2d(
-            input_tensor=x,
-            weight_tensor=tt_kernel_bottom_right,
-            in_channels=C,
-            out_channels=C,
-            device=self.device,
-            kernel_size=(2, 2),
-            stride=(2, 2),
-            padding=(0, 0),
-            groups=C,
-            batch_size=B,
-            input_height=H,
-            input_width=W,
-            conv_config=None,
-            dtype=ttnn.bfloat16,
-            memory_config=ttnn.DRAM_MEMORY_CONFIG,
-        )
+        # Common convolution parameters
+        conv_params = {
+            "input_tensor": x,
+            "in_channels": C,
+            "out_channels": C,
+            "device": self.device,
+            "kernel_size": (2, 2),
+            "stride": (2, 2),
+            "padding": (0, 0),
+            "groups": C,  # Grouped convolution
+            "batch_size": B,
+            "input_height": H,
+            "input_width": W,
+            "conv_config": None,
+            "dtype": ttnn.bfloat16,
+            "memory_config": ttnn.DRAM_MEMORY_CONFIG,
+        }
+
+        # Apply grouped convolutions for each patch; used instead of a slice operation
+        x0 = ttnn.conv2d(weight_tensor=self.kernel_top_left, **conv_params)
+        x1 = ttnn.conv2d(weight_tensor=self.kernel_bottom_left, **conv_params)
+        x2 = ttnn.conv2d(weight_tensor=self.kernel_top_right, **conv_params)
+        x3 = ttnn.conv2d(weight_tensor=self.kernel_bottom_right, **conv_params)
 
         ttnn.deallocate(x)
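
After the four convolutions, standard Swin patch merging concatenates the four corner tensors into 4*C channels, applies LayerNorm, and linearly reduces to 2*C. A hedged PyTorch reference (not part of this commit, following the usual Swin formulation) that the TTNN forward presumably mirrors past this point:

    import torch
    import torch.nn as nn

    class PatchMergingRef(nn.Module):
        # Reference Swin-style patch merging: every 2x2 neighborhood -> one token
        def __init__(self, dim):
            super().__init__()
            self.norm = nn.LayerNorm(4 * dim)
            self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

        def forward(self, x, H, W):
            B, L, C = x.shape
            x = x.view(B, H, W, C)
            x0 = x[:, 0::2, 0::2, :]  # top-left, cf. kernel_top_left
            x1 = x[:, 1::2, 0::2, :]  # bottom-left
            x2 = x[:, 0::2, 1::2, :]  # top-right
            x3 = x[:, 1::2, 1::2, :]  # bottom-right
            x = torch.cat([x0, x1, x2, x3], dim=-1)  # B, H/2, W/2, 4*C
            x = x.view(B, -1, 4 * C)
            return self.reduction(self.norm(x))

Moving the one-hot kernels into parameters means forward() no longer needs torch at all, which is why the import torch line could be dropped above.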
