4 | 4 |
5 | 5 |
6 | 6 | def window_partition_ttnn(x, window_size): |
7 | | - """TTNN implementation of window partitioning""" |
8 | | - b, h, w, c = x.shape |
| 7 | + """Partition into non-overlapping windows""" |
| 8 | + B, H, W, C = x.shape |
| 9 | + num_windows = (H // window_size) * (W // window_size) |
| 10 | + return ttnn.reshape(x, [B * num_windows, window_size, window_size, C], memory_config=ttnn.L1_MEMORY_CONFIG) |
9 | 11 |
10 | | - # Reshape: (b, h, w, c) -> (b, h//ws, ws, w//ws, ws, c) |
11 | | - reshaped = ttnn.reshape(x, (b, h // window_size, window_size, w // window_size, window_size, c)) |
12 | 12 |
13 | | - # Permute: (0, 1, 3, 2, 4, 5) -> group windows together |
14 | | - permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5)) |
15 | | - |
16 | | - # Final reshape to get windows |
17 | | - windows = ttnn.reshape(permuted, (-1, window_size, window_size, c)) |
18 | | - |
19 | | - return windows |
20 | | - |
21 | | - |
22 | | -def window_reverse_ttnn(windows, window_size, h, w): |
23 | | - """TTNN implementation of window reverse""" |
24 | | - b = int(windows.shape[0] / (h * w / window_size / window_size)) |
25 | | - |
26 | | - # Reshape windows back to grid |
27 | | - reshaped = ttnn.reshape(windows, (b, h // window_size, w // window_size, window_size, window_size, -1)) |
28 | | - |
29 | | - # Permute back to original order |
30 | | - permuted = ttnn.permute(reshaped, (0, 1, 3, 2, 4, 5)) |
31 | | - |
32 | | - # Final reshape to original spatial dimensions |
33 | | - output = ttnn.reshape(permuted, (b, h, w, -1)) |
34 | | - |
35 | | - return output |
| 13 | +def window_reverse_ttnn(windows, window_size, H, W): |
| 14 | + B = windows.shape[0] // (H * W // window_size // window_size) |
| 15 | + return ttnn.reshape(windows, [B, H, W, -1], memory_config=ttnn.L1_MEMORY_CONFIG) |
36 | 16 |
37 | 17 |
38 | 18 | class TTOCAB(LightweightModule): |
@@ -118,124 +98,125 @@ def ttnn_rearrange_host(self, tensor, pattern_from, pattern_to, **kwargs): |
118 | 98 | def forward(self, x, x_size, rpi): |
119 | 99 | h, w = x_size |
120 | 100 | b, _, c = x.shape |
121 | | - |
122 | | - # Store shortcut connection |
123 | 101 | shortcut = x |
124 | 102 |
125 | | - # Layer normalization - handle padded dimensions |
126 | | - x = ttnn.layer_norm(x, weight=self.norm1_weight, bias=self.norm1_bias) |
127 | | - |
128 | | - # Reshape to spatial format |
129 | | - x = ttnn.reshape(x, (b, h, w, c)) |
130 | | - |
131 | | - # QKV projection |
132 | | - qkv = ttnn.linear(x, self.qkv_weight, bias=self.qkv_bias) |
133 | | - qkv = ttnn.reshape(qkv, (b, h, w, 3, c)) |
134 | | - qkv = ttnn.permute(qkv, (3, 0, 4, 1, 2)) # 3, b, c, h, w |
135 | | - |
136 | | - # Split Q, K, V using slicing |
137 | | - q = ttnn.slice(qkv, (0, 0, 0, 0, 0), (1, b, c, h, w)) |
138 | | - q = ttnn.squeeze(q, 0) # Remove first dimension |
139 | | - q = ttnn.permute(q, (0, 2, 3, 1)) # b, h, w, c |
140 | | - |
141 | | - k = ttnn.slice(qkv, (1, 0, 0, 0, 0), (2, b, c, h, w)) |
142 | | - k = ttnn.squeeze(k, 0) |
143 | | - |
144 | | - v = ttnn.slice(qkv, (2, 0, 0, 0, 0), (3, b, c, h, w)) |
145 | | - v = ttnn.squeeze(v, 0) |
146 | | - |
147 | | - # Concatenate K and V for unfold operation |
148 | | - kv = ttnn.concat([k, v], dim=1) # b, 2*c, h, w |
149 | | - |
150 | | - # Window partition for Q |
151 | | - q_windows = window_partition_ttnn(q, self.window_size) |
152 | | - q_windows = ttnn.reshape(q_windows, (-1, self.window_size * self.window_size, c)) |
153 | | - |
154 | | - # Host-side unfold operation for KV |
155 | | - kv_torch = ttnn.to_torch(kv) |
156 | | - kv_windows_torch = self._unfold(kv_torch) # b, c*w*w, nw |
157 | | - kv_windows = ttnn.from_torch( |
158 | | - kv_windows_torch, |
159 | | - dtype=kv.dtype, |
160 | | - layout=ttnn.ROW_MAJOR_LAYOUT, |
161 | | - device=self.device, |
162 | | - memory_config=ttnn.DRAM_MEMORY_CONFIG, |
| 103 | + # Layer normalization |
| 104 | + x = ttnn.layer_norm(x, weight=self.norm1_weight, bias=self.norm1_bias, memory_config=ttnn.L1_MEMORY_CONFIG) |
| 105 | + |
| 106 | + # Fused QKV projection - use single linear operation |
| 107 | + qkv = ttnn.linear( |
| 108 | + x, |
| 109 | + self.qkv_weight, |
| 110 | + bias=self.qkv_bias, |
| 111 | + memory_config=ttnn.L1_MEMORY_CONFIG, |
| 112 | + dtype=ttnn.bfloat8_b, |
| 113 | + core_grid=ttnn.CoreGrid(y=8, x=8), |
163 | 114 | ) |
164 | | - |
165 | | - # Rearrange KV windows using host implementation |
166 | | - kv_windows = self.ttnn_rearrange_host( |
167 | | - kv_windows, |
168 | | - "b (nc ch owh oww) nw", |
169 | | - "nc (b nw) (owh oww) ch", |
170 | | - nc=2, |
171 | | - ch=c, |
172 | | - owh=self.overlap_win_size, |
173 | | - oww=self.overlap_win_size, |
| 115 | + tile_size = 32 |
| 116 | + # e.g. with dim=180, qkv.shape[-1] == 540, so head_size = 30 and padded_head_size = 32
| 117 | + head_size = qkv.shape[-1] // (3 * self.num_heads) |
| 118 | + padded_head_size = ((head_size + tile_size - 1) // tile_size) * tile_size |
| 119 | + pad = padded_head_size != head_size |
| 120 | + if pad: |
| 121 | + qkv_torch = ttnn.to_torch(qkv) |
| 122 | + input_tensor_heads = torch.split(qkv_torch, head_size, dim=-1) |
| 123 | + input_tensor_heads = [ |
| 124 | + torch.nn.functional.pad(head, (0, padded_head_size - head_size), "constant", 0) |
| 125 | + for head in input_tensor_heads |
| 126 | + ] |
| 127 | + qkv = torch.cat(input_tensor_heads, dim=-1) |
| 128 | + qkv = ttnn.from_torch( |
| 129 | + qkv, |
| 130 | + device=self.device, |
| 131 | + dtype=ttnn.bfloat16, |
| 132 | + memory_config=ttnn.L1_MEMORY_CONFIG, |
| 133 | + layout=ttnn.TILE_LAYOUT, |
| 134 | + ) |
| 135 | + # Split fused QKV into per-head Q, K, V (heads padded to the tile boundary above)
| 136 | + query, key, value = ttnn.transformer.split_query_key_value_and_split_heads( |
| 137 | + qkv, memory_config=ttnn.DRAM_MEMORY_CONFIG, num_heads=self.num_heads, transpose_key=False |
174 | 138 | ) |
| 139 | + ttnn.deallocate(qkv) |
175 | 140 |
176 | | - # Split K and V windows |
177 | | - k_windows = ttnn.slice( |
178 | | - kv_windows, (0, 0, 0, 0), (1, kv_windows.shape[1], kv_windows.shape[2], kv_windows.shape[3]) |
| 141 | + sdpa_program_config = ttnn.SDPAProgramConfig( |
| 142 | + compute_with_storage_grid_size=[8, 7], |
| 143 | + q_chunk_size=512, |
| 144 | + k_chunk_size=512, |
| 145 | + exp_approx_mode=False, |
179 | 146 | ) |
180 | | - k_windows = ttnn.squeeze(k_windows, 0) |
181 | | - |
182 | | - v_windows = ttnn.slice( |
183 | | - kv_windows, (1, 0, 0, 0), (2, kv_windows.shape[1], kv_windows.shape[2], kv_windows.shape[3]) |
| 147 | + compute_kernel_config = ttnn.init_device_compute_kernel_config( |
| 148 | + self.device.arch(), |
| 149 | + math_fidelity=ttnn.MathFidelity.LoFi, |
| 150 | + math_approx_mode=True, |
| 151 | + fp32_dest_acc_en=False, |
| 152 | + packer_l1_acc=False, |
184 | 153 | ) |
185 | | - v_windows = ttnn.squeeze(v_windows, 0) |
186 | | - |
187 | | - # Multi-head attention computation |
188 | | - b_, nq, _ = q_windows.shape |
189 | | - _, n, _ = k_windows.shape |
190 | | - d = self.dim // self.num_heads |
191 | | - |
192 | | - # Reshape for multi-head attention |
193 | | - q = ttnn.reshape(q_windows, (b_, nq, self.num_heads, d)) |
194 | | - q = ttnn.permute(q, (0, 2, 1, 3)) # nw*b, nH, nq, d |
195 | | - |
196 | | - k = ttnn.reshape(k_windows, (b_, n, self.num_heads, d)) |
197 | | - k = ttnn.permute(k, (0, 2, 1, 3)) # nw*b, nH, n, d |
198 | | - |
199 | | - v = ttnn.reshape(v_windows, (b_, n, self.num_heads, d)) |
200 | | - v = ttnn.permute(v, (0, 2, 1, 3)) # nw*b, nH, n, d |
201 | | - |
202 | | - q = ttnn.to_layout(q, ttnn.TILE_LAYOUT) |
203 | | - k = ttnn.to_layout(k, ttnn.TILE_LAYOUT) |
204 | | - v = ttnn.to_layout(v, ttnn.TILE_LAYOUT) |
205 | | - |
206 | | - # Scale queries |
207 | | - q = ttnn.multiply(q, self.scale) |
208 | | - |
209 | | - # Attention computation |
210 | | - k_transposed = ttnn.transpose(k, -2, -1) |
211 | | - attn = ttnn.matmul(q, k_transposed) |
212 | | - |
213 | | - # Apply softmax |
214 | | - attn = ttnn.softmax(attn, dim=-1) |
215 | 154 |
216 | | - # Apply attention to values |
217 | | - attn_output = ttnn.matmul(attn, v) |
218 | | - attn_output = ttnn.transpose(attn_output, 1, 2) |
219 | | - attn_output = ttnn.reshape(attn_output, (b_, nq, self.dim)) |
220 | | - |
221 | | - # Merge windows |
222 | | - attn_windows = ttnn.reshape(attn_output, (-1, self.window_size, self.window_size, self.dim)) |
223 | | - x = window_reverse_ttnn(attn_windows, self.window_size, h, w) |
224 | | - x = ttnn.reshape(x, (b, h * w, self.dim)) |
| 155 | + # Use optimized scaled dot product attention |
| 156 | + attention_output = ttnn.transformer.scaled_dot_product_attention( |
| 157 | + query, |
| 158 | + key, |
| 159 | + value, |
| 160 | + is_causal=False, |
| 161 | + scale=self.scale, |
| 162 | + program_config=sdpa_program_config, |
| 163 | + compute_kernel_config=compute_kernel_config, |
| 164 | + memory_config=ttnn.L1_MEMORY_CONFIG, |
| 165 | + ) |
225 | 166 |
226 | | - # Projection and residual connection |
227 | | - x = ttnn.linear(x, self.proj_weight, bias=self.proj_bias) |
| 167 | + # Deallocate intermediate tensors |
| 168 | + ttnn.deallocate(query) |
| 169 | + ttnn.deallocate(key) |
| 170 | + ttnn.deallocate(value) |
| 171 | + # Use transformer function for head concatenation |
| 172 | + context_layer = ttnn.transformer.concatenate_heads( |
| 173 | + attention_output, |
| 174 | + memory_config=ttnn.L1_MEMORY_CONFIG, |
| 175 | + ) |
| 176 | + ttnn.deallocate(attention_output) |
| 177 | + |
| 178 | + # Undo the per-head tile padding applied before splitting heads
| 179 | + if pad: |
| 180 | + # remove padding |
| 181 | + context_layer = ttnn.to_torch(context_layer)[..., : self.dim] # slice to 180 and remove padding |
| 182 | + context_layer = ttnn.from_torch( |
| 183 | + context_layer, |
| 184 | + device=self.device, |
| 185 | + dtype=ttnn.bfloat16, |
| 186 | + memory_config=ttnn.L1_MEMORY_CONFIG, |
| 187 | + layout=ttnn.TILE_LAYOUT, |
| 188 | + ) |
| 189 | + # Reshape back to original format |
| 190 | + x = ttnn.reshape(context_layer, (b, h * w, self.dim), memory_config=ttnn.L1_MEMORY_CONFIG) |
| 191 | + |
| 192 | + # Output projection and residual |
| 193 | + x = ttnn.linear( |
| 194 | + x, |
| 195 | + self.proj_weight, |
| 196 | + bias=self.proj_bias, |
| 197 | + memory_config=ttnn.L1_MEMORY_CONFIG, |
| 198 | + core_grid=ttnn.CoreGrid(y=8, x=8), |
| 199 | + ) |
228 | 200 | x = ttnn.add(x, shortcut) |
229 | 201 |
230 | | - # MLP block |
231 | | - x = ttnn.layer_norm(x, weight=self.norm2_weight, bias=self.norm2_bias) |
| 202 | + x = ttnn.layer_norm(x, weight=self.norm2_weight, bias=self.norm2_bias, memory_config=ttnn.L1_MEMORY_CONFIG) |
232 | 203 |
233 | | - # MLP forward pass |
234 | | - mlp_out = ttnn.linear(x, self.mlp_fc1_weight, bias=self.mlp_fc1_bias) |
235 | | - mlp_out = ttnn.gelu(mlp_out) |
236 | | - mlp_out = ttnn.linear(mlp_out, self.mlp_fc2_weight, bias=self.mlp_fc2_bias) |
| 204 | + mlp_out = ttnn.linear( |
| 205 | + x, |
| 206 | + self.mlp_fc1_weight, |
| 207 | + bias=self.mlp_fc1_bias, |
| 208 | + memory_config=ttnn.L1_MEMORY_CONFIG, |
| 209 | + dtype=ttnn.bfloat8_b, |
| 210 | + core_grid=ttnn.CoreGrid(y=8, x=8), |
| 211 | + activation="gelu", |
| 212 | + ) |
| 213 | + mlp_out = ttnn.linear( |
| 214 | + mlp_out, |
| 215 | + self.mlp_fc2_weight, |
| 216 | + bias=self.mlp_fc2_bias, |
| 217 | + memory_config=ttnn.L1_MEMORY_CONFIG, |
| 218 | + core_grid=ttnn.CoreGrid(y=8, x=8), |
| 219 | + ) |
237 | 220 |
238 | | - # Final residual connection |
239 | 221 | x = ttnn.add(x, mlp_out) |
240 | | - |
241 | 222 | return x |