 class TTMaskTokenInference(LightweightModule):
-    # def __init__(
-    #     self, device, parameters, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0
-    # ):
-    #     self.device = device
-    #     self.dim = dim
-    #     self.num_heads = num_heads
-    #     self.head_dim = dim // num_heads
-    #     self.scale = qk_scale or (self.head_dim**-0.5)
-
-    #     # Layer norm parameters (would need to be loaded from state dict)
-    #     self.norm_weight = parameters["norm"]["weight"]  # ttnn tensor for layer norm weight
-    #     self.norm_bias = parameters["norm"]["bias"]  # ttnn tensor for layer norm bias
-
-    #     # Linear layer weights (would need to be preprocessed and loaded)
-    #     self.q_weight = parameters["q"]["weight"]  # ttnn tensor for query projection
-    #     self.k_weight = parameters["k"]["weight"]  # ttnn tensor for key projection
-    #     self.v_weight = parameters["v"]["weight"]  # ttnn tensor for value projection
-    #     self.proj_weight = parameters["proj"]["weight"]  # ttnn tensor for output projection
-
-    #     self.q_bias = parameters["q"]["bias"] if qkv_bias else None
-    #     self.k_bias = parameters["k"]["bias"] if qkv_bias else None
-    #     self.v_bias = parameters["v"]["bias"] if qkv_bias else None
-    #     self.proj_bias = parameters["proj"]["bias"]
-
-    #     # Scale tensor
-    #     scale_tensor = torch.tensor(self.scale).view(1, 1, 1, 1)
-    #     self.tt_scale = ttnn.from_torch(scale_tensor, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
-
     def __init__(
         self, device, parameters, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0
     ):
@@ -61,71 +33,6 @@ def __init__(
         scale_tensor = torch.tensor(self.scale).view(1, 1, 1, 1)
         self.tt_scale = ttnn.from_torch(scale_tensor, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)

-    # def __call__(self, fea):
-    #     B, N, C = fea.shape
-
-    #     # Layer normalization
-    #     x = ttnn.layer_norm(fea, weight=self.norm_weight, bias=self.norm_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
-    #     fea_skip = fea
-    #     fea_skip = ttnn.reallocate(fea_skip, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-    #     ttnn.deallocate(fea)
-
-    #     # Split into classification token and feature tokens
-    #     # T_s: classification token [B, 1, C]
-    #     # F_s: feature tokens [B, N-1, C]
-    #     T_s = ttnn.slice(x, [0, 0, 0], [B, 1, C])
-    #     F_s = ttnn.slice(x, [0, 1, 0], [B, N, C])
-    #     ttnn.deallocate(x)
-
-    #     # Query from feature tokens
-    #     q = ttnn.linear(F_s, self.q_weight, bias=self.q_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
-    #     # q = ttnn.reshape(q, (B, N - 1, self.num_heads, self.head_dim))
-    #     # q = ttnn.permute(q, (0, 2, 1, 3))
-
-    #     # Key from classification token
-    #     k = ttnn.linear(T_s, self.k_weight, bias=self.k_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
-    #     # k = ttnn.reshape(k, (B, 1, self.num_heads, self.head_dim))
-    #     # k = ttnn.permute(k, (0, 2, 1, 3))
-
-    #     # Value from classification token
-    #     v = ttnn.linear(T_s, self.v_weight, bias=self.v_bias, memory_config=ttnn.L1_MEMORY_CONFIG)
-    #     # v = ttnn.reshape(v, (B, 1, self.num_heads, self.head_dim))
-    #     # v = ttnn.permute(v, (0, 2, 1, 3))
-
-    #     # Attention computation: q @ k.T
-    #     k_transposed = ttnn.transpose(k, -2, -1)
-    #     attn = ttnn.matmul(q, k_transposed)
-
-    #     # Scale attention scores
-    #     attn = ttnn.multiply(attn, self.tt_scale)
-
-    #     # Apply sigmoid instead of softmax
-    #     attn = ttnn.sigmoid(attn)
-
-    #     # Apply attention dropout (if needed, would require custom implementation)
-    #     # attn = apply_dropout(attn, attn_drop)
-
-    #     # Compute attention output
-    #     infer_fea = ttnn.matmul(attn, v)
-
-    #     # Reshape back to [B, N-1, C]
-    #     infer_fea = ttnn.permute(infer_fea, (0, 2, 1, 3))
-    #     infer_fea = ttnn.to_layout(infer_fea, layout=ttnn.ROW_MAJOR_LAYOUT)
-    #     infer_fea = ttnn.reshape(infer_fea, (B, N - 1, C))
-    #     infer_fea = ttnn.to_layout(infer_fea, layout=ttnn.TILE_LAYOUT)
-
-    #     # Output projection
-    #     infer_fea = ttnn.linear(infer_fea, self.proj_weight, bias=self.proj_bias)
-
-    #     # Apply projection dropout (if needed)
-    #     # infer_fea = apply_dropout(infer_fea, proj_drop)
-
-    #     # Residual connection with original feature tokens
-    #     original_features = ttnn.slice(fea_skip, [0, 1, 0], [B, N, C])
-    #     infer_fea = ttnn.add(infer_fea, original_features)
-
-    #     return infer_fea
-
     def __call__(self, fea):
         B, N, C = fea.shape

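For reference, the algorithm sketched in the deleted comments can also be written as a few lines of plain PyTorch. The sketch below is hypothetical and not part of this commit: it assumes a single head, omits dropout, and uses standard nn.Linear / nn.LayerNorm modules in place of the preprocessed ttnn parameters. Feature tokens supply the queries, the classification token supplies the key and value, the scores pass through a sigmoid instead of a softmax, and the projected result is added back to the original feature tokens.

import torch
import torch.nn as nn


class MaskTokenInferenceRef(nn.Module):
    """Hypothetical PyTorch reference for the sigmoid cross-attention above."""

    def __init__(self, dim, qkv_bias=False, qk_scale=None):
        super().__init__()
        self.scale = qk_scale or dim**-0.5  # single head, so head_dim == dim
        self.norm = nn.LayerNorm(dim)
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, fea):  # fea: [B, N, C]
        x = self.norm(fea)
        T_s, F_s = x[:, :1, :], x[:, 1:, :]  # class token / feature tokens
        q = self.q(F_s)  # queries from feature tokens      [B, N-1, C]
        k = self.k(T_s)  # key from classification token    [B, 1, C]
        v = self.v(T_s)  # value from classification token  [B, 1, C]
        attn = torch.sigmoid(q @ k.transpose(-2, -1) * self.scale)  # [B, N-1, 1]
        infer_fea = self.proj(attn @ v)  # [B, N-1, C]
        return infer_fea + fea[:, 1:, :]  # residual on the original feature tokens

As a quick check of the shapes, MaskTokenInferenceRef(dim=384)(torch.randn(2, 197, 384)) returns a [2, 196, 384] tensor, i.e. one refined embedding per feature token with the classification token consumed as the key/value source.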