
Commit 9e79d30

OCAB opt
1 parent 1f3d0fe commit 9e79d30

File tree

1 file changed: +114 −133 lines


models/experimental/SSR/tt/OCAB.py

Lines changed: 114 additions & 133 deletions
@@ -4,35 +4,15 @@
 
 
 def window_partition_ttnn(x, window_size):
-    """TTNN implementation of window partitioning"""
-    b, h, w, c = x.shape
+    """Partition into non-overlapping windows"""
+    B, H, W, C = x.shape
+    num_windows = (H // window_size) * (W // window_size)
+    return ttnn.reshape(x, [B * num_windows, window_size, window_size, C], memory_config=ttnn.L1_MEMORY_CONFIG)
 
-    # Reshape: (b, h, w, c) -> (b, h//ws, ws, w//ws, ws, c)
-    reshaped = ttnn.reshape(x, (b, h // window_size, window_size, w // window_size, window_size, c))
 
-    # Permute: (0, 1, 3, 2, 4, 5) -> group windows together
-    permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5))
-
-    # Final reshape to get windows
-    windows = ttnn.reshape(permuted, (-1, window_size, window_size, c))
-
-    return windows
-
-
-def window_reverse_ttnn(windows, window_size, h, w):
-    """TTNN implementation of window reverse"""
-    b = int(windows.shape[0] / (h * w / window_size / window_size))
-
-    # Reshape windows back to grid
-    reshaped = ttnn.reshape(windows, (b, h // window_size, w // window_size, window_size, window_size, -1))
-
-    # Permute back to original order
-    permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5))
-
-    # Final reshape to original spatial dimensions
-    output = ttnn.reshape(permuted, (b, h, w, -1))
-
-    return output
+def window_reverse_ttnn(windows, window_size, H, W):
+    B = windows.shape[0] // (H * W // window_size // window_size)
+    return ttnn.reshape(windows, [B, H, W, -1], memory_config=ttnn.L1_MEMORY_CONFIG)
 
 
 class TTOCAB(LightweightModule):
@@ -118,124 +98,125 @@ def ttnn_rearrange_host(self, tensor, pattern_from, pattern_to, **kwargs):
     def forward(self, x, x_size, rpi):
         h, w = x_size
         b, _, c = x.shape
-
-        # Store shortcut connection
         shortcut = x
 
-        # Layer normalization - handle padded dimensions
-        x = ttnn.layer_norm(x, weight=self.norm1_weight, bias=self.norm1_bias)
-
-        # Reshape to spatial format
-        x = ttnn.reshape(x, (b, h, w, c))
-
-        # QKV projection
-        qkv = ttnn.linear(x, self.qkv_weight, bias=self.qkv_bias)
-        qkv = ttnn.reshape(qkv, (b, h, w, 3, c))
-        qkv = ttnn.permute(qkv, (3, 0, 4, 1, 2))  # 3, b, c, h, w
-
-        # Split Q, K, V using slicing
-        q = ttnn.slice(qkv, (0, 0, 0, 0, 0), (1, b, c, h, w))
-        q = ttnn.squeeze(q, 0)  # Remove first dimension
-        q = ttnn.permute(q, (0, 2, 3, 1))  # b, h, w, c
-
-        k = ttnn.slice(qkv, (1, 0, 0, 0, 0), (2, b, c, h, w))
-        k = ttnn.squeeze(k, 0)
-
-        v = ttnn.slice(qkv, (2, 0, 0, 0, 0), (3, b, c, h, w))
-        v = ttnn.squeeze(v, 0)
-
-        # Concatenate K and V for unfold operation
-        kv = ttnn.concat([k, v], dim=1)  # b, 2*c, h, w
-
-        # Window partition for Q
-        q_windows = window_partition_ttnn(q, self.window_size)
-        q_windows = ttnn.reshape(q_windows, (-1, self.window_size * self.window_size, c))
-
-        # Host-side unfold operation for KV
-        kv_torch = ttnn.to_torch(kv)
-        kv_windows_torch = self._unfold(kv_torch)  # b, c*w*w, nw
-        kv_windows = ttnn.from_torch(
-            kv_windows_torch,
-            dtype=kv.dtype,
-            layout=ttnn.ROW_MAJOR_LAYOUT,
-            device=self.device,
-            memory_config=ttnn.DRAM_MEMORY_CONFIG,
+        # Layer normalization
+        x = ttnn.layer_norm(x, weight=self.norm1_weight, bias=self.norm1_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
+
+        # Fused QKV projection - use a single linear operation
+        qkv = ttnn.linear(
+            x,
+            self.qkv_weight,
+            bias=self.qkv_bias,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+            dtype=ttnn.bfloat8_b,
+            core_grid=ttnn.CoreGrid(y=8, x=8),
         )
-
-        # Rearrange KV windows using host implementation
-        kv_windows = self.ttnn_rearrange_host(
-            kv_windows,
-            "b (nc ch owh oww) nw",
-            "nc (b nw) (owh oww) ch",
-            nc=2,
-            ch=c,
-            owh=self.overlap_win_size,
-            oww=self.overlap_win_size,
+        tile_size = 32
+        # e.g. dim = 180, qkv.shape[-1] = 540 gives head_size = 30, padded_head_size = 32
+        head_size = qkv.shape[-1] // (3 * self.num_heads)
+        padded_head_size = ((head_size + tile_size - 1) // tile_size) * tile_size
+        pad = padded_head_size != head_size
+        if pad:
+            qkv_torch = ttnn.to_torch(qkv)
+            input_tensor_heads = torch.split(qkv_torch, head_size, dim=-1)
+            input_tensor_heads = [
+                torch.nn.functional.pad(head, (0, padded_head_size - head_size), "constant", 0)
+                for head in input_tensor_heads
+            ]
+            qkv = torch.cat(input_tensor_heads, dim=-1)
+            qkv = ttnn.from_torch(
+                qkv,
+                device=self.device,
+                dtype=ttnn.bfloat16,
+                memory_config=ttnn.L1_MEMORY_CONFIG,
+                layout=ttnn.TILE_LAYOUT,
+            )
+        # Use transformer function for QKV splitting
+        query, key, value = ttnn.transformer.split_query_key_value_and_split_heads(
+            qkv, memory_config=ttnn.DRAM_MEMORY_CONFIG, num_heads=self.num_heads, transpose_key=False
         )
+        ttnn.deallocate(qkv)
 
-        # Split K and V windows
-        k_windows = ttnn.slice(
-            kv_windows, (0, 0, 0, 0), (1, kv_windows.shape[1], kv_windows.shape[2], kv_windows.shape[3])
+        sdpa_program_config = ttnn.SDPAProgramConfig(
+            compute_with_storage_grid_size=[8, 7],
+            q_chunk_size=512,
+            k_chunk_size=512,
+            exp_approx_mode=False,
         )
-        k_windows = ttnn.squeeze(k_windows, 0)
-
-        v_windows = ttnn.slice(
-            kv_windows, (1, 0, 0, 0), (2, kv_windows.shape[1], kv_windows.shape[2], kv_windows.shape[3])
+        compute_kernel_config = ttnn.init_device_compute_kernel_config(
+            self.device.arch(),
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            math_approx_mode=True,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=False,
         )
-        v_windows = ttnn.squeeze(v_windows, 0)
-
-        # Multi-head attention computation
-        b_, nq, _ = q_windows.shape
-        _, n, _ = k_windows.shape
-        d = self.dim // self.num_heads
-
-        # Reshape for multi-head attention
-        q = ttnn.reshape(q_windows, (b_, nq, self.num_heads, d))
-        q = ttnn.permute(q, (0, 2, 1, 3))  # nw*b, nH, nq, d
-
-        k = ttnn.reshape(k_windows, (b_, n, self.num_heads, d))
-        k = ttnn.permute(k, (0, 2, 1, 3))  # nw*b, nH, n, d
-
-        v = ttnn.reshape(v_windows, (b_, n, self.num_heads, d))
-        v = ttnn.permute(v, (0, 2, 1, 3))  # nw*b, nH, n, d
-
-        q = ttnn.to_layout(q, ttnn.TILE_LAYOUT)
-        k = ttnn.to_layout(k, ttnn.TILE_LAYOUT)
-        v = ttnn.to_layout(v, ttnn.TILE_LAYOUT)
-
-        # Scale queries
-        q = ttnn.multiply(q, self.scale)
-
-        # Attention computation
-        k_transposed = ttnn.transpose(k, -2, -1)
-        attn = ttnn.matmul(q, k_transposed)
-
-        # Apply softmax
-        attn = ttnn.softmax(attn, dim=-1)
 
-        # Apply attention to values
-        attn_output = ttnn.matmul(attn, v)
-        attn_output = ttnn.transpose(attn_output, 1, 2)
-        attn_output = ttnn.reshape(attn_output, (b_, nq, self.dim))
-
-        # Merge windows
-        attn_windows = ttnn.reshape(attn_output, (-1, self.window_size, self.window_size, self.dim))
-        x = window_reverse_ttnn(attn_windows, self.window_size, h, w)
-        x = ttnn.reshape(x, (b, h * w, self.dim))
+        # Use optimized scaled dot product attention
+        attention_output = ttnn.transformer.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            is_causal=False,
+            scale=self.scale,
+            program_config=sdpa_program_config,
+            compute_kernel_config=compute_kernel_config,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+        )
 
-        # Projection and residual connection
-        x = ttnn.linear(x, self.proj_weight, bias=self.proj_bias)
+        # Deallocate intermediate tensors
+        ttnn.deallocate(query)
+        ttnn.deallocate(key)
+        ttnn.deallocate(value)
+        # Use transformer function for head concatenation
+        context_layer = ttnn.transformer.concatenate_heads(
+            attention_output,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+        )
+        ttnn.deallocate(attention_output)
+
+        if pad:
+            # Remove the padding: slice back to self.dim channels
+            context_layer = ttnn.to_torch(context_layer)[..., : self.dim]
+            context_layer = ttnn.from_torch(
+                context_layer,
+                device=self.device,
+                dtype=ttnn.bfloat16,
+                memory_config=ttnn.L1_MEMORY_CONFIG,
+                layout=ttnn.TILE_LAYOUT,
+            )
+        # Reshape back to original format
+        x = ttnn.reshape(context_layer, (b, h * w, self.dim), memory_config=ttnn.L1_MEMORY_CONFIG)
+
+        # Output projection and residual
+        x = ttnn.linear(
+            x,
+            self.proj_weight,
+            bias=self.proj_bias,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+            core_grid=ttnn.CoreGrid(y=8, x=8),
+        )
         x = ttnn.add(x, shortcut)
 
-        # MLP block
-        x = ttnn.layer_norm(x, weight=self.norm2_weight, bias=self.norm2_bias)
+        x = ttnn.layer_norm(x, weight=self.norm2_weight, bias=self.norm2_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
 
-        # MLP forward pass
-        mlp_out = ttnn.linear(x, self.mlp_fc1_weight, bias=self.mlp_fc1_bias)
-        mlp_out = ttnn.gelu(mlp_out)
-        mlp_out = ttnn.linear(mlp_out, self.mlp_fc2_weight, bias=self.mlp_fc2_bias)
+        mlp_out = ttnn.linear(
+            x,
+            self.mlp_fc1_weight,
+            bias=self.mlp_fc1_bias,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+            dtype=ttnn.bfloat8_b,
+            core_grid=ttnn.CoreGrid(y=8, x=8),
+            activation="gelu",
+        )
+        mlp_out = ttnn.linear(
+            mlp_out,
+            self.mlp_fc2_weight,
+            bias=self.mlp_fc2_bias,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+            core_grid=ttnn.CoreGrid(y=8, x=8),
+        )
 
-        # Final residual connection
         x = ttnn.add(x, mlp_out)
-
         return x
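
A minimal host-side sketch, in plain PyTorch, of the tile-padding arithmetic used by the new `if pad:` branch. The dim = 180 / head_size = 30 / padded_head_size = 32 figures come from the in-code comment (which also implies num_heads = 6); the batch and token counts are hypothetical, not the module's actual shapes:

import torch

# Tile-size padding of the per-head channels, mirroring the `if pad:` branch in the diff.
# head_size = 30 is not a multiple of the 32-element tile, so each head is zero-padded
# to 32 before the QKV split and sliced back to dim channels after the heads are merged.
tile_size = 32
dim, num_heads = 180, 6                    # num_heads = 6 is implied by head_size = 30
qkv = torch.randn(1, 64, 3 * dim)          # hypothetical (batch, tokens, 3 * dim) activation

head_size = qkv.shape[-1] // (3 * num_heads)                                # 30
padded_head_size = ((head_size + tile_size - 1) // tile_size) * tile_size   # 32

heads = torch.split(qkv, head_size, dim=-1)                                 # 3 * num_heads chunks
heads = [torch.nn.functional.pad(h, (0, padded_head_size - head_size)) for h in heads]
qkv_padded = torch.cat(heads, dim=-1)
print(qkv_padded.shape)                    # torch.Size([1, 64, 576]) == (1, 64, 3 * num_heads * 32)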
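The split_query_key_value_and_split_heads → scaled_dot_product_attention → concatenate_heads sequence is the standard split-heads / attention / merge-heads pattern. Below is a plain-PyTorch sketch of that pattern, assuming a fused QKV layout of (batch, tokens, 3 * dim) and PyTorch 2.1+ for the explicit scale argument; it illustrates the computation only and is not the TTNN kernels themselves:

import torch
import torch.nn.functional as F

def reference_attention(qkv: torch.Tensor, num_heads: int, scale: float) -> torch.Tensor:
    # qkv: (batch, tokens, 3 * dim) fused projection, analogous to the output of the
    # single ttnn.linear above (head padding omitted for clarity).
    b, n, three_d = qkv.shape
    d = three_d // 3
    head_size = d // num_heads
    q, k, v = qkv.chunk(3, dim=-1)                       # split Q, K, V
    # (b, n, d) -> (b, num_heads, n, head_size): the split-heads step
    q, k, v = (t.view(b, n, num_heads, head_size).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=False, scale=scale)
    # (b, num_heads, n, head_size) -> (b, n, d): the concatenate-heads step
    return out.transpose(1, 2).reshape(b, n, d)

x = torch.randn(2, 64, 3 * 192)                          # hypothetical (windows, tokens, 3 * dim)
y = reference_attention(x, num_heads=6, scale=(192 // 6) ** -0.5)
print(y.shape)                                           # torch.Size([2, 64, 192])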
