1+ from copy import deepcopy
2+ from typing import Dict , Any
3+
4+ import torch as th
5+ import torch .nn as nn
6+ import timm
7+ import tensorly as tl
8+ tl .set_backend ("pytorch" )
9+
10+
def cp_attn(self, x):
    """CaRA forward pass for a timm ViT ``Attention`` module.

    Bound onto each Attention block by ``set_cara`` (replaces the block's
    ``forward``). Adds a low-rank CP-reconstructed delta to both the qkv
    projection and the output projection, reading the shared CP factors off
    the module-level ``global_model`` (set by ``cara``).
    """
    B, N, C = x.shape  # batch, tokens, embed dim
    qkv = self.qkv(x)
    # This block's 3 rows (one per q/k/v) of the shared attention factor;
    # attn_idx advances by 3 per Attention block in set_cara.
    f1 = global_model.CP_A1[self.attn_idx:self.attn_idx + 3, :]
    # Reconstruct the dense qkv delta weight from CP weights CP_R1 and
    # factors (f1, CP_A2, CP_A3, CP_A4).
    tensor_attn = tl.cp_to_tensor((global_model.CP_R1, (f1, global_model.CP_A2, global_model.CP_A3, global_model.CP_A4)))
    # Presumably (3, embed_dim, num_heads, head_dim) given the factor
    # shapes declared in set_cara — TODO confirm.
    K, E, H, D = tensor_attn.shape
    tensor_attn = tensor_attn.reshape((K, E, H * D))
    # dp is dropout on the reconstructed weight; einsum applies the delta
    # weight per projection k: (B,N,E) x (K,E,H*D) -> (K,B,N,H*D).
    qkv_delta = th.einsum("bnd, kde->kbne", x, self.dp(tensor_attn))
    qkv_delta = qkv_delta.reshape(3, B, N, self.num_heads, C // self.num_heads).permute(
        0, 1, 3, 2, 4
    )
    # Match timm's qkv layout: (3, B, heads, N, head_dim).
    qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(
        2, 0, 3, 1, 4
    )
    # s is the adapter scale set in set_cara.
    qkv += qkv_delta * self.s
    q, k, v = qkv[0], qkv[1], qkv[2]
    # Standard scaled dot-product attention, as in timm's Attention.forward.
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, C)

    proj = self.proj(x)
    # Single row of the shared projection factor for this block
    # (idx advances by 1 per Attention block in set_cara).
    p1 = global_model.CP_P1[self.idx:self.idx + 1, :]
    tensor_proj = tl.cp_to_tensor((global_model.CP_R2, (p1, global_model.CP_P2, global_model.CP_P3)))
    AA, AB, AC = tensor_proj.shape
    # Flatten leading dims so the delta acts as a (C, C)-like matrix.
    tensor_proj = tensor_proj.reshape((AA * AB, AC))
    proj_delta = x @ self.dp(tensor_proj.T) + global_model.CP_bias1
    proj += proj_delta * self.s
    x = self.proj_drop(proj)
    return x
42+
43+
def cp_mlp(self, x):
    """CaRA forward pass for a timm ``Mlp`` module.

    Bound onto each Mlp block by ``set_cara`` (replaces the block's
    ``forward``). Adds CP-reconstructed low-rank deltas to both the up
    (fc1) and down (fc2) projections, reading shared CP factors off the
    module-level ``global_model`` (set by ``cara``).
    """
    # Each Mlp block consumes 8 rows of CP_P1: 4 for the up projection,
    # 4 for the down projection (idx advances by 8 in set_cara).
    p1_up = global_model.CP_P1[self.idx:self.idx + 4, :]
    p1_down = global_model.CP_P1[self.idx + 4: self.idx + 8, :]

    up = self.fc1(x)
    tensor_up = tl.cp_to_tensor((global_model.CP_R2, (p1_up, global_model.CP_P2, global_model.CP_P3)))
    AA, AB, AC = tensor_up.shape
    # Flatten so the delta acts as an (in_dim, hidden_dim)-like matrix.
    tensor_up = tensor_up.reshape((AA * AB, AC))
    # dp = dropout on the reconstructed weight; s = adapter scale.
    up_delta = x @ self.dp(tensor_up.T) + global_model.CP_bias2
    up += up_delta * self.s

    x = self.act(up)
    x = self.drop(x)

    down = self.fc2(x)
    tensor_down = tl.cp_to_tensor((global_model.CP_R2, (p1_down, global_model.CP_P2, global_model.CP_P3)))
    # NOTE(review): reuses AA/AB/AC captured from tensor_up above — only
    # valid because p1_up and p1_down are same-shape slices, so both
    # reconstructions share a shape. Fragile if the slicing ever changes.
    tensor_down = tensor_down.reshape((AA * AB, AC))
    # No transpose here: down maps hidden_dim -> out_dim.
    down_delta = x @ self.dp(tensor_down) + global_model.CP_bias3
    down += down_delta * self.s
    x = self.drop(down)
    return x
65+
66+
def set_cara(model: nn.Module, rank: int, scale: float, l_mu: float, l_std: float) -> None:
    """Recursively install CaRA adapter parameters and patched forwards.

    On the root ``VisionTransformer`` this declares the shared CP factors
    and slice counters; on every recursion level it rebinds the ``forward``
    of each ``Attention`` / ``Mlp`` child to :func:`cp_attn` / :func:`cp_mlp`.

    Expects the module-level ``global_model`` to already reference the root
    model (set by :func:`cara`), since the slice counters ``idx`` /
    ``attn_idx`` are advanced on it across recursive calls.

    Args:
        model: module (or submodule, when recursing) to patch.
        rank: CP decomposition rank.
        scale: adapter scaling factor applied to every delta.
        l_mu: mean for the CP weight (lambda) initialisation.
        l_std: std for the CP weight initialisation; 0 means constant init.
    """
    if isinstance(model, timm.models.vision_transformer.VisionTransformer):
        # Declare CaRA parameters. Dimensions assume ViT-Base:
        # embed dim 768, 12 heads, 12 blocks — TODO generalise if needed.
        # 36 = 12 blocks x 3 rows (q, k, v) each.
        model.CP_A1 = nn.Parameter(th.empty([36, rank]), requires_grad=True)
        model.CP_A2 = nn.Parameter(th.empty([768, rank]), requires_grad=True)
        model.CP_A3 = nn.Parameter(th.empty([12, rank]), requires_grad=True)
        model.CP_A4 = nn.Parameter(th.empty([768 // 12, rank]), requires_grad=True)
        # 108 = 12 blocks x (1 attn-proj row + 8 MLP rows).
        model.CP_P1 = nn.Parameter(th.empty([108, rank]), requires_grad=True)
        model.CP_P2 = nn.Parameter(th.empty([768, rank]), requires_grad=True)
        model.CP_P3 = nn.Parameter(th.empty([768, rank]), requires_grad=True)
        model.CP_R1 = nn.Parameter(th.empty([rank]), requires_grad=True)
        model.CP_R2 = nn.Parameter(th.empty([rank]), requires_grad=True)
        model.CP_bias1 = nn.Parameter(th.empty([768]), requires_grad=True)
        model.CP_bias2 = nn.Parameter(th.empty([768 * 4]), requires_grad=True)
        model.CP_bias3 = nn.Parameter(th.empty([768]), requires_grad=True)
        # Initialise CaRA parameters.
        nn.init.xavier_normal_(model.CP_A1)
        nn.init.zeros_(model.CP_A2)  # zero factor => zero initial delta
        nn.init.orthogonal_(model.CP_A3)
        nn.init.orthogonal_(model.CP_A4)
        nn.init.xavier_normal_(model.CP_P1)
        nn.init.zeros_(model.CP_P2)  # zero factor => zero initial delta
        nn.init.orthogonal_(model.CP_P3)
        if l_std != 0.0:
            nn.init.normal_(model.CP_R1, mean=l_mu, std=l_std)
            nn.init.normal_(model.CP_R2, mean=l_mu, std=l_std)
        else:
            # l_std == 0: deterministic init at l_mu. The original code only
            # handled l_mu == 1.0 here, leaving CP_R1/CP_R2 as uninitialised
            # memory (th.empty garbage) for any other l_mu.
            nn.init.constant_(model.CP_R1, l_mu)
            nn.init.constant_(model.CP_R2, l_mu)
        nn.init.zeros_(model.CP_bias1)
        nn.init.zeros_(model.CP_bias2)
        nn.init.zeros_(model.CP_bias3)
        # CaRA slice counters, consumed/advanced while patching children.
        model.idx = 0
        model.attn_idx = 0
    for child in model.children():
        if isinstance(child, timm.models.vision_transformer.Attention):
            child.dp = nn.Dropout(0.1)  # dropout on reconstructed weights
            child.s = scale
            child.dim = rank
            # Counters live on the root model (via the module global) so
            # they keep advancing across recursive calls.
            child.idx = global_model.idx
            child.attn_idx = global_model.attn_idx
            global_model.idx += 1  # 1 CP_P1 row for the output projection
            global_model.attn_idx += 3  # 3 CP_A1 rows for q, k, v
            # Rebind forward to the CaRA-adapted attention.
            bound_method = cp_attn.__get__(child, child.__class__)
            setattr(child, "forward", bound_method)
        elif isinstance(child, timm.models.layers.mlp.Mlp):
            child.dp = nn.Dropout(0.1)
            child.s = scale
            child.dim = rank
            child.idx = global_model.idx
            global_model.idx += 8  # 4 up + 4 down CP_P1 rows
            bound_method = cp_mlp.__get__(child, child.__class__)
            setattr(child, "forward", bound_method)
        elif len(list(child.children())) != 0:
            # Recurse into containers (e.g. blocks Sequential) to reach
            # the Attention / Mlp leaves.
            set_cara(child, rank, scale, l_mu, l_std)
123+
124+
def cara(config):
    """Entry point: install CaRA adapters on ``config["model"]``.

    Publishes the model through the module-level ``global_model`` (read by
    the patched forwards) and returns it.

    Expected config keys: ``model``, ``rank``, ``scale``, ``l_mu``, ``l_std``.
    """
    global global_model
    global_model = config["model"]
    set_cara(
        global_model,
        config["rank"],
        config["scale"],
        config["l_mu"],
        config["l_std"],
    )
    return global_model
0 commit comments