1+ import torch
2+ import timm
3+ import numpy as np
4+
5+ from einops import repeat , rearrange
6+ from einops .layers .torch import Rearrange
7+
8+
9+ # 这里可以用两个timm模型进行构建我们的结果
10+ from timm .models .layers import trunc_normal_
11+ from timm .models .vision_transformer import Block
12+
def random_indexes(size: int):
    """Return a random permutation of [0, size) plus its inverse.

    `forward_indexes` shuffles a sequence of length `size`;
    `backward_indexes` undoes that shuffle (forward[backward[i]] == i),
    which lets the caller restore the original patch order later.
    """
    # np.random.permutation(n) is arange(n) + shuffle, so the RNG stream
    # (and thus the result under a fixed seed) is unchanged.
    forward_indexes = np.random.permutation(size)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes
18+
def take_indexes(sequences, indexes):
    """Gather rows of `sequences` (T, B, C) along dim 0, per batch column.

    `indexes` has shape (T', B); output row (t, b) is
    sequences[indexes[t, b], b, :].
    """
    t, b = indexes.shape
    # Broadcast the (T', B) index map across the channel dimension so it
    # matches the (T', B, C) shape torch.gather expects.
    gather_idx = indexes.unsqueeze(-1).expand(t, b, sequences.shape[-1])
    return torch.gather(sequences, 0, gather_idx)
21+
class PatchShuffle(torch.nn.Module):
    """Randomly permute patches along the sequence axis and keep a subset.

    With mask ratio r, keeps the first int(T * (1 - r)) patches of each
    (independently) shuffled sequence; the rest are considered masked.
    """

    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio  # fraction of patches to mask out

    def forward(self, patches: torch.Tensor):
        num_patches, batch_size, _ = patches.shape  # (T, B, C)
        num_visible = int(num_patches * (1 - self.ratio))

        # Draw one independent permutation per batch element and stack them
        # column-wise into (T, B) index tensors.
        perms = [random_indexes(num_patches) for _ in range(batch_size)]
        forward_indexes = torch.as_tensor(
            np.stack([fwd for fwd, _ in perms], axis=-1),
            dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(
            np.stack([bwd for _, bwd in perms], axis=-1),
            dtype=torch.long).to(patches.device)

        # Shuffle every sequence, then truncate to the visible prefix,
        # e.g. [T * 0.25, B, C] for a 0.75 mask ratio.
        shuffled = take_indexes(patches, forward_indexes)
        visible = shuffled[:num_visible]

        return visible, forward_indexes, backward_indexes
39+
class MAE_Encoder(torch.nn.Module):
    """ViT encoder for MAE: patchify, add positions, mask, then transform.

    forward() returns (features, backward_indexes) where features is
    (1 + num_visible, B, emb_dim) — cls token first — and backward_indexes
    lets the decoder restore the original patch order.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        # One positional embedding per patch, broadcast over the batch dim.
        self.pos_embedding = torch.nn.Parameter(
            torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))

        # Shuffles the patch sequence and drops the masked patches.
        self.shuffle = PatchShuffle(mask_ratio)

        # Conv with kernel == stride == patch_size acts as the patch embedding
        # (input assumed to be 3-channel images).
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        # Stack of standard timm ViT blocks.
        self.transformer = torch.nn.Sequential(
            *[Block(emb_dim, num_head) for _ in range(num_layer)])

        # Final LayerNorm, as in a standard ViT.
        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        # Truncated-normal init for the class token and positional embeddings.
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        tokens = self.patchify(img)
        tokens = rearrange(tokens, 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding

        # Keep only the visible patches; backward_indexes is returned so the
        # decoder can undo the shuffle.
        tokens, forward_indexes, backward_indexes = self.shuffle(tokens)

        cls = self.cls_token.expand(-1, tokens.shape[1], -1)
        tokens = torch.cat([cls, tokens], dim=0)
        tokens = rearrange(tokens, 't b c -> b t c')
        features = self.layer_norm(self.transformer(tokens))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes
85+
class MAE_Decoder(torch.nn.Module):
    """MAE decoder: fill in mask tokens, restore patch order, reconstruct.

    forward() returns (img, mask): the reconstructed image and a pixel mask
    that is 1 on pixels belonging to masked patches (where the loss should
    be applied) and 0 on visible patches.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        # Shared learnable token substituted for every masked patch.
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        # +1 slot for the cls token passed through from the encoder.
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        # Projects each patch feature back to its raw pixel values.
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size // patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        # features: (T, B, C) with T = 1 (cls token) + number of visible patches.
        T = features.shape[0]
        # Prepend index 0 for the cls token and shift all patch indexes by one.
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        # Append a mask token for every masked patch, then restore the
        # original (unshuffled) patch order.
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding  # add positional information

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:]  # drop the cls token, keep per-patch features

        patches = self.head(features)  # project features back to pixel patches
        mask = torch.zeros_like(patches)
        # In shuffled order the first T-1 rows are the visible patches and
        # everything from row T-1 on was masked. The previous `mask[T:] = 1`
        # was an off-by-one (T also counts the cls token): it left the first
        # masked patch flagged as visible, so the loss skipped one masked
        # patch per image.
        mask[T - 1:] = 1
        # Map the mask back to the original patch order to match `patches`.
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)  # reassemble patches into the image
        mask = self.patch2img(mask)

        return img, mask
130+
class MAE_ViT(torch.nn.Module):
    """Full masked autoencoder: MAE_Encoder followed by MAE_Decoder.

    forward() maps an image batch to (predicted_img, mask), where mask
    marks the pixels of masked patches.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim,
                                   encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim,
                                   decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        return self.decoder(features, backward_indexes)
151+
class ViT_Classifier(torch.nn.Module):
    """Classification head on top of a (pre-trained) MAE encoder.

    Reuses the encoder's patch embedding, positional embeddings, cls token,
    transformer and final LayerNorm (shared modules, not copies), and adds a
    linear head that classifies from the cls-token feature.
    """

    def __init__(self, encoder: MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

    def forward(self, img):
        tokens = self.patchify(img)
        tokens = rearrange(tokens, 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding
        # No shuffling here: classification sees every patch.
        cls = self.cls_token.expand(-1, tokens.shape[1], -1)
        tokens = torch.cat([cls, tokens], dim=0)
        tokens = rearrange(tokens, 't b c -> b t c')
        features = self.layer_norm(self.transformer(tokens))
        features = rearrange(features, 'b t c -> t b c')
        # Classify from the cls-token feature (sequence position 0).
        return self.head(features[0])
172+
173+
if __name__ == '__main__':
    # Smoke-test PatchShuffle: 16 patches, batch of 2, keep 25%.
    shuffle = PatchShuffle(0.75)
    dummy_patches = torch.rand(16, 2, 10)
    visible, forward_indexes, backward_indexes = shuffle(dummy_patches)
    print(visible.shape)

    # Smoke-test the encoder/decoder round trip on a CIFAR-sized batch.
    img = torch.rand(2, 3, 32, 32)
    encoder = MAE_Encoder()
    decoder = MAE_Decoder()
    features, backward_indexes = encoder(img)
    print(forward_indexes.shape)
    predicted_img, mask = decoder(features, backward_indexes)
    print(predicted_img.shape)
    # MAE loss: MSE on masked pixels only, normalized by the mask ratio.
    loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75)