
Commit 13ebd90

rot impl
Signed-off-by: cliu-us <[email protected]>
1 parent 4705c75 commit 13ebd90

4 files changed, 275 insertions(+), 0 deletions(-)


fms_mo/quant/quantizers.py

Lines changed: 25 additions & 0 deletions
@@ -40,6 +40,9 @@
 import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F

+# Local
+from fms_mo.quant.rotation import RotQuantWrapper
+
 logger = logging.getLogger(__name__)


@@ -66,8 +69,16 @@ def get_activation_quantizer(
     - pact/pact+/pactsym
     - sawb/sawb+
     - max
+
+    If qa_mode has a "rot_" prefix or "_rot" suffix, wrap the quantizer with RotQuantWrapper();
+    remember to set up the R_left, R_right tensors later.
     """

+    use_rot = False
+    if "rot_" in qa_mode or "_rot" in qa_mode:
+        use_rot = True
+        qa_mode = qa_mode.replace("rot_", "").replace("_rot", "")
+
     if not use_swcap:
         QPACTLUT = {
             "pact_uni": PACT,
@@ -220,6 +231,9 @@ def get_activation_quantizer(
             f"activation quantization mode {qa_mode} is incompatible with swcap"
         )

+    if use_rot:
+        act_quantizer = RotQuantWrapper(act_quantizer)
+
     return act_quantizer


@@ -245,7 +259,15 @@ def get_weight_quantizer(
     SWCAP quantizers:
     - sawb/sawb+
     - max
+    If qw_mode has a "rot_" prefix or "_rot" suffix, wrap the quantizer with RotQuantWrapper();
+    remember to set up the R_left, R_right tensors later.
     """
+
+    use_rot = False
+    if "rot_" in qw_mode or "_rot" in qw_mode:
+        use_rot = True
+        qw_mode = qw_mode.replace("rot_", "").replace("_rot", "")
+
     weight_quantizer = None
     if not use_swcap:
         cggrad = "cgpact" in qw_mode
@@ -367,6 +389,9 @@ def get_weight_quantizer(
             f"activation quantized mode {qw_mode} is incompatible with swcap"
        )

+    if use_rot:
+        weight_quantizer = RotQuantWrapper(weight_quantizer)
+
     return weight_quantizer


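A brief illustration of the new mode-string handling above. The helper below is a standalone, hypothetical mirror of the added prefix/suffix logic (it is not part of this commit) so the behavior can be checked without constructing a real quantizer:

    def split_rot_mode(mode: str):
        # Mirrors the added handling: detect "rot_"/"_rot" and strip it from the mode string.
        use_rot = "rot_" in mode or "_rot" in mode
        bare_mode = mode.replace("rot_", "").replace("_rot", "")
        return use_rot, bare_mode

    print(split_rot_mode("rot_pactsym"))  # (True, 'pactsym')
    print(split_rot_mode("sawb+_rot"))    # (True, 'sawb+')
    print(split_rot_mode("max"))          # (False, 'max')

When use_rot is True, the factory builds the quantizer for the bare mode and then wraps it in RotQuantWrapper; the caller still has to attach R_left / R_right afterwards, as the docstrings note.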
fms_mo/quant/rotation.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Util functions related to Hadamard rotation."""
+
+# Third Party
+import torch
+
+# Local
+from fms_mo.utils.hadamard_util import matmul_hadU_cuda
+
+
+class RotQuantWrapper(torch.nn.Module):
+    """Wrapper around fms-mo quantizers. Objects of this class can hold two rotation tensors,
+    and the basic formula is:
+
+        self.quantizer(self.R_left @ input_tensor @ self.R_right)
+
+    NOTE: R_left/R_right are optional, depending on whether the wrapper is used for weights or
+    activations. For example, in SpinQuant the QKV Linears look like (pseudo-code, "self" does
+    not refer to the same object on each line):
+
+        qx = self.quantize_feature(x)  # no rotation, just a normal quantizer
+        qw_q = self.quantize_weight(self.weight, R1_t)      # needs left rotation only
+        qw_k = self.quantize_weight(self.weight, R1_t)
+        qw_v = self.quantize_weight(self.weight, R1_t, R2)  # needs both left and right rotation
+
+        return F.linear(qx, qw, bias)
+
+    For the MLP down_proj:
+
+        qx = self.quantize_feature(x, None, R4)  # for activations it should be x @ R
+        qw = self.quantize_weight(self.weight, R4_t, R1)
+
+        return F.linear(qx, qw, bias)
+
+    Also make sure self.R_left / self.R_right point to nn.Parameter() objects if the rotations
+    need to be trained.
+    """
+
+    def __init__(self, quantizer, *args, **kwargs):
+        self.online_full_had = kwargs.pop("online_full_had", None)
+        self.fp32_had = kwargs.pop("fp32_had", None)
+        super().__init__(*args, **kwargs)
+        self.quantizer = quantizer
+        self.R_left = None
+        self.R_right = None
+        self.K_left = None  # if K_xxx > 1, R_xxx is a special Hadamard matrix
+        self.K_right = None
+
+    def forward(self, input_tensor):
+        org_dtype = input_tensor.dtype
+
+        if self.online_full_had:
+            # Online Hadamard => rotation for activations, i.e. input_tensor @ R_right.
+            # It cannot be fused into W and is not trainable, either.
+            if self.fp32_had:
+                input_tensor = input_tensor.float()
+            input_tensor = matmul_hadU_cuda(
+                input_tensor, self.R_right, self.K_right
+            ).to(org_dtype)
+        else:
+            # Not online => rotation for weights, could be fused into W later.
+            if self.R_left is not None:
+                input_tensor = self.R_left @ input_tensor
+            if self.R_right is not None:
+                input_tensor = input_tensor @ self.R_right
+
+        # Apply the wrapped quantizer after rotation, per the formula in the class docstring.
+        if self.quantizer is not None:
+            input_tensor = self.quantizer(input_tensor)
+
+        return input_tensor

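A minimal sketch of how the wrapper is meant to be wired up, following the class docstring. The dummy quantizer, the hidden size, and the QR-based way of producing an orthogonal "R1" below are illustrative assumptions, not fms-mo APIs:

    import torch
    from fms_mo.quant.rotation import RotQuantWrapper

    class DummyQuantizer(torch.nn.Module):
        # Stand-in for a real fms-mo quantizer; just passes the tensor through.
        def forward(self, x):
            return x

    hidden = 512
    R1 = torch.linalg.qr(torch.randn(hidden, hidden))[0]  # a random orthogonal matrix as "R1"

    w_quant = RotQuantWrapper(DummyQuantizer())
    w_quant.R_left = R1.T                            # e.g. a Q/K projection weight takes R1_t on the left
    q_weight = w_quant(torch.randn(hidden, hidden))  # R1_t @ W, then the wrapped quantizer

For activations with an online Hadamard (e.g. R4 before down_proj), online_full_had would be set instead and R_right / K_right would come from get_hadK, since that path runs matmul_hadU_cuda rather than a plain matmul.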
fms_mo/utils/hadamard_util.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This code is based on QuaRot (https://github.com/spcl/QuaRot/tree/main/quarot).
+# Licensed under Apache License 2.0.
+# Adapted from https://github.com/Cornell-RelaxML/quip-sharp/blob/main/lib/utils/matmul_had.py
+# and https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
+"""
+Changed the original "text tensor" implementation into binary safetensors for better efficiency.
+Only 12 special sizes are available in the safetensors file:
+[12, 20, 28, 36, 40, 44, 52, 60, 108, 140, 156, 172]
+"""
+
+# Third Party
+from fast_hadamard_transform import hadamard_transform
+from safetensors import safe_open
+import torch
+
+
+class HadamardTransform(torch.autograd.Function):
+    """The unnormalized Hadamard transform (i.e. without dividing by sqrt(2))."""
+
+    # TODO seems redundant; inside hadamard_transform(), backward is already handled...?
+    @staticmethod
+    def forward(ctx, u):
+        return hadamard_transform(u)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return hadamard_transform(grad)
+
+
+def get_hadK(n, transpose=False):
+    """Simplified implementation that uses binary tensors instead of the text implementation."""
+    hadK = None
+    for K in [172, 156, 140, 108, 60, 52, 44, 40, 36, 28, 20, 12]:
+        if n % K == 0 and is_pow2(n // K):
+            with safe_open("hadk.safetensors", framework="pt") as f:
+                assert (
+                    str(K) in f.keys()
+                ), f"Special size Hadamard {K} does not exist in the file."
+                hadK = f.get_tensor(str(K))
+
+            if transpose:
+                hadK = hadK.T
+
+            break
+
+    if hadK is None:
+        if is_pow2(n):
+            K = 1
+        else:
+            raise RuntimeError(
+                f"{n} is not a power of 2 and does not have a special size Hadamard available."
+            )
+
+    return hadK, K
+
+
+def matmul_hadU(X, transpose=False):
+    n = X.shape[-1]
+    hadK, K = get_hadK(n, transpose)
+    input = X.clone().view(-1, n, 1)
+    output = input.clone()
+    while input.shape[1] > K:
+        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
+        output = output.view(input.shape)
+        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
+        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
+        output = output.view(input.shape[0], input.shape[1], -1)
+        (input, output) = (output, input)
+    del output
+
+    if K > 1:
+        # Do not explicitly repeat - OOM
+        # input = torch.bmm(
+        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+        # Use broadcasting instead
+        input = hadK.view(1, K, K).to(input) @ input
+
+    return input.view(X.shape) / torch.tensor(n).sqrt()
+
+
+def matmul_hadUt(X):
+    return matmul_hadU(X, transpose=True)
+
+
+def random_hadamard_matrix(size, device):
+    # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation"
+    Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
+    Q = Q * 2 - 1
+    Q = torch.diag(Q)
+    return matmul_hadU(Q).to(device)
+
+
+def hadamard_matrix(size, device):
+    # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation"
+    Q = torch.eye(size)
+    return matmul_hadU(Q).to(device)
+
+
+def matmul_hadU_cuda(X, hadK, K, transpose=False):
+    n = X.shape[-1]
+    if K == 1:
+        return HadamardTransform.apply(X.contiguous()) / torch.tensor(n).sqrt()
+
+    if transpose:
+        hadK = hadK.T.contiguous()
+    input = X.view(-1, K, n // K)
+    input = HadamardTransform.apply(input.contiguous()) / torch.tensor(n).sqrt()
+    input = hadK.to(input.device).to(input.dtype) @ input
+    return input.reshape(X.shape)
+
+
+def matmul_hadUt_cuda(X, hadK, K):
+    return matmul_hadU_cuda(X, hadK, K, transpose=True)
+
+
+def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
+    assert isinstance(module, torch.nn.Linear)
+    in_features, out_features = module.in_features, module.out_features
+
+    if had_dim != -1:
+        assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!"
+
+    W_ = module.weight.data
+    dtype = W_.dtype
+    dev = W_.device
+    init_shape = W_.shape
+    W_ = W_.float().cuda()
+
+    if had_dim == -1:
+        if output:
+            had_K, K = get_hadK(out_features)
+            W_ = matmul_hadU_cuda(W_.t(), had_K, K).t()
+        if not output:
+            had_K, K = get_hadK(in_features)
+            W_ = matmul_hadU_cuda(W_, had_K, K)
+    else:
+        hadK = hadamard_matrix(had_dim, "cuda").to(torch.float64)
+        if R2 is not None:
+            hadK = R2.to(torch.float64)
+        if output:
+            W_ = W_.t()
+            transposed_shape = W_.shape
+            temp = W_.reshape(-1, transposed_shape[-1] // had_dim, had_dim)
+            temp = temp.to(torch.float64) @ hadK
+            W_ = temp.reshape(transposed_shape).t()
+        else:
+            init_shape = W_.shape
+            temp = W_.reshape(-1, init_shape[-1] // had_dim, had_dim)
+            temp = temp.to(torch.float64) @ hadK
+            W_ = temp.reshape(init_shape)
+    module.weight.data = W_.to(device=dev, dtype=dtype)
+
+
+def is_pow2(n):
+    return (n & (n - 1) == 0) and (n > 0)
+
+
+# Hadamard matrices for had12, had36.pal2, had52.will,
+# had60.pal, had108.pal, had140.pal, had156.will, had172.will:
+# http://www.neilsloane.com/hadamard/index.html

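A quick sanity-check sketch for the pure-PyTorch path added above. Power-of-2 sizes never touch hadk.safetensors, though importing the module still requires fast_hadamard_transform and safetensors to be installed; the size 512 here is an arbitrary choice:

    import torch
    from fms_mo.utils.hadamard_util import matmul_hadU, random_hadamard_matrix

    x = torch.randn(4, 512, dtype=torch.float64)
    y = matmul_hadU(x)  # normalized Hadamard transform along the last dim
    print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1)))  # orthogonal transform preserves norms

    R = random_hadamard_matrix(512, "cpu")  # randomized Hadamard matrix, R @ R.T ≈ I
    print(torch.allclose(R @ R.T, torch.eye(512, dtype=torch.float64)))

matmul_hadU_cuda and apply_exact_had_to_linear additionally need a CUDA device, since they go through the fast_hadamard_transform kernel and an explicit .cuda() call, respectively.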
fms_mo/utils/hadk.safetensors

382 KB
Binary file not shown.
