12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import Optional
15
+ from typing import Optional , Union
16
16
17
17
import torch
18
18
from compressed_tensors .transform import TransformArgs , TransformScheme
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
41
41
def __init__(
    self,
    name: str,
    scheme: TransformScheme,
    seed: Optional[int] = None,
):
    """
    Initialize the factory and its two lookup tables.

    :param name: forwarded unchanged to the parent constructor
    :param scheme: forwarded unchanged to the parent constructor
    :param seed: optional seed, forwarded unchanged to the parent constructor
    """
    super().__init__(name, scheme, seed)

    # NOTE(review): ParameterizedDefaultDict is defined outside this view;
    # presumably it builds missing entries by calling the given factory with
    # the lookup key and caches the result -- confirm against its definition.
    self.weights = ParameterizedDefaultDict(self._create_weight)
    self.perms = ParameterizedDefaultDict(self._create_permutation)
44
45
45
46
def create_transform (self , module : Module , args : TransformArgs ):
46
47
"""
@@ -56,24 +57,35 @@ def create_transform(self, module: Module, args: TransformArgs):
56
57
device = get_offloaded_device (module )
57
58
58
59
weight = self .weights [size , dtype , device ]
59
- return HadamardTransform (weight , args )
60
+ perm = self .perms [weight ] if self .scheme .randomize else None
61
+ return HadamardTransform (weight , perm , args )
60
62
61
63
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
    """
    Build the Hadamard weight parameter for the given size, dtype and device.

    :param size: passed to `deterministic_hadamard_matrix` as the matrix size
    :param dtype: dtype requested from the builder and enforced via `.to`
    :param device: device requested from the builder and enforced via `.to`
    :return: parameter wrapping the matrix, with `requires_grad` taken from
        `self.scheme.requires_grad`
    """
    hadamard = deterministic_hadamard_matrix(size, dtype, device)
    # The explicit cast is kept from the original even though dtype/device
    # were already passed to the builder.
    hadamard = hadamard.to(dtype=dtype, device=device)
    trainable = self.scheme.requires_grad
    return Parameter(hadamard, requires_grad=trainable)
65
67
68
+ def _create_permutation (self , weight : Parameter ) -> Parameter :
69
+ data = torch .randperm (weight .size (0 ), generator = self .generator )
70
+ return Parameter (data , requires_grad = False )
71
+
66
72
67
73
class HadamardTransform(TransformBase):
    """
    Transform module that applies a stored weight matrix, optionally permuted
    and optionally transposed, to an input value.

    :param weight: matrix used as the transform weight
    :param perm: index tensor applied to both the rows and the columns of
        `weight` before use, or None to use `weight` unpermuted
    :param args: transform arguments; `forward` reads `inverse` and `location`
    """

    def __init__(
        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
    ):
        super().__init__()
        self.weight = weight
        self.perm = perm
        self.args = args

    def forward(self, value: Tensor) -> Tensor:
        """
        Apply the (optionally permuted, optionally transposed) weight to `value`.

        :param value: tensor handed to `apply_transform_weight`
        :return: result of `apply_transform_weight` for `self.args.location`
        """
        matrix = self.weight
        permutation = self.perm

        if permutation is not None:
            # Reorder rows first, then columns, using the same index tensor.
            matrix = matrix[permutation][:, permutation]

        # The transpose is taken after the permutation, matching the original
        # statement order.
        if self.args.inverse:
            matrix = matrix.T

        return apply_transform_weight(matrix, value, self.args.location)
0 commit comments