
Commit 2d78feb

Merge branch 'main' into build
Signed-off-by: tharapalanivel <[email protected]>
2 parents 9c70870 + 16fc615 commit 2d78feb

File tree
17 files changed: +1330 −49 lines


.gitignore

Lines changed: 7 additions & 1 deletion
@@ -38,4 +38,10 @@ venv/
 dictionary.dic
 
 # Generated error log
-error.log
+error.log
+
+# Files generated from running examples
+fms_mo.log
+data_train/
+data_test/
+act_scales/

fms_mo/aiu_addons/gptq/__init__.py

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement FMS adapter for GPTQ W4A16 checkpoints"""

# Standard
from typing import Mapping

# Third Party
from fms.utils import serialization
import torch


def _gptq_qweights_transpose_aiu(
    input_sd: Mapping[str, torch.Tensor],
) -> Mapping[str, torch.Tensor]:
    new_sd = {}
    for name, param in input_sd.items():
        new_sd[name] = param
        # for AIU, qweights are needed as [out_feat, in_feat]
        if "qweight" in name:
            new_sd[name] = new_sd[name].t()
        elif "g_idx" in name:
            new_sd[name] = torch.zeros(1, dtype=torch.int32, device=param.device)
    return new_sd


serialization.register_adapter_step(
    "llama", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu
)
serialization.register_adapter_step(
    "gpt_bigcode", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu
)
serialization.register_adapter(
    "llama",
    "hf_gptq_aiu",
    [
        "hf_to_fms_names",
        "hf_to_fms_rope",
        "hf_gptq_fusion_check",
        "weight_fusion",
        "gptq_qweights_transpose_aiu",
    ],
)
serialization.register_adapter(
    "gpt_bigcode",
    "hf_gptq_aiu",
    ["hf_to_fms_names", "weight_fusion", "gptq_qweights_transpose_aiu"],
)
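The adapter step above only reorients GPTQ tensors: `qweight` matrices are transposed to `[out_feat, in_feat]` and `g_idx` is collapsed to a single zero. A minimal sketch of its effect on a toy state dict (the tensor names and sizes are illustrative, not taken from a real checkpoint):

import torch

# hypothetical GPTQ-style entries for a 4096x4096 layer; shapes are illustrative only
sd = {
    "layer0.qweight": torch.zeros(512, 4096, dtype=torch.int32),  # [in_feat // 8, out_feat], 4-bit packed
    "layer0.g_idx": torch.arange(4096, dtype=torch.int32),        # one group index per input feature
}
new_sd = _gptq_qweights_transpose_aiu(sd)
print(new_sd["layer0.qweight"].shape)  # torch.Size([4096, 512]): transposed for AIU
print(new_sd["layer0.g_idx"])          # tensor([0], dtype=torch.int32): placeholder expected by AIU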
Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement GPTQ W4A16 linear module compatible with AIU compiler"""

# Standard
from typing import Any, Mapping, Optional
import math

# Third Party
from fms.modules.linear import (
    LinearModuleShardingInfo,
    LinearParameterShardingInfo,
    register_linear_type_to_module_map,
    register_linear_type_to_sharding_map,
    shard_base_linear,
)
from fms.modules.tp import ShardType, TPModule
from fms.utils.gptq import GPTQLinearConfig
import torch
import torch.nn as nn

# Local
from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op

register_aiu_gptq_op()


class GPTQLinearAIU(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        config: GPTQLinearConfig,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bits = config.bits
        self.group_size = config.group_size if config.group_size != -1 else in_features
        self.desc_act = config.desc_act
        # self.weight_transposed = True

        if self.bits not in [4]:
            raise NotImplementedError(
                "AIU GPTQLinear only supports 4 bits quantization."
            )
        if in_features % self.group_size != 0:
            raise ValueError("`in_features` must be divisible by `group_size`.")
        if in_features % 32 or out_features % 32:
            raise ValueError("`in_features` and `out_features` must be divisible by 32")
        if self.desc_act:
            raise NotImplementedError(
                "AIU GPTQLinear does not support activation reordering (`desc_act`)"
            )

        # Register quantization parameters
        self.register_buffer(
            "qweight",
            torch.zeros(
                # transposed w.r.t. GPTQ ckpt (AIU requirement)
                (out_features, in_features // 32 * self.bits),
                dtype=torch.int32,
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (
                    math.ceil(in_features / self.group_size),
                    out_features // 32 * self.bits,
                ),
                dtype=torch.int32,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (math.ceil(in_features / self.group_size), out_features),
                dtype=torch.float16,
            ),
        )
        # AIU requirement
        self.register_buffer("g_idx", torch.tensor([0], dtype=torch.int32))
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros((out_features), dtype=torch.float16),
            )
        else:
            self.bias = None

        # Register op
        if not hasattr(torch.ops, "gptq_gemm") or not hasattr(
            torch.ops.gptq_gemm, "i4f16_fxinputs_aiu"
        ):
            raise ValueError(
                "Custom AIU op `gptq_gemm.i4f16_fxinputs_aiu` has not been registered."
            )
        self.aiu_op = torch.ops.gptq_gemm.i4f16_fxinputs_aiu

    def forward(self, x):
        x = self.aiu_op(
            x.half(),
            self.qweight,
            self.qzeros,
            self.scales,
            self.g_idx,
        )
        if self.bias is not None:
            x.add_(self.bias)
        return x

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(in={self.in_features}, out={self.out_features}, "
            f"bias={self.bias is not None}, group={self.group_size}, "
            f"op={self.aiu_op})"
        )


def get_gptq_aiu_linear(
    in_features: int,
    out_features: int,
    bias: bool,
    linear_config: Optional[Mapping[str, Any]] = None,
):
    gptq_config = GPTQLinearConfig(**linear_config)
    if gptq_config.desc_act:
        raise NotImplementedError(
            "Activation reordering (desc_act=True) not supported on AIU"
        )
    linear = GPTQLinearAIU(
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        config=gptq_config,
    )
    setattr(linear, "desc_act", gptq_config.desc_act)
    return linear


def shard_gptq_aiu_linear(
    tensor_values: dict[str, torch.Tensor],
    tp_module: TPModule,
    module_sharding_info: dict[str, LinearModuleShardingInfo],
) -> Optional[set]:
    """
    Set up GPTQ quantization parameters to be sharded onto
    AIU-compliant linear modules

                           |     GPU     |
    sharding  | qparam     | shard | dim |
    ----------+------------+-------+-----+
    colwise   | qweight    |   Y   |  0  |
              | bias       |   Y   |  0  |
              | scales     |   Y   |  1  |
              | qzeros     |   Y   |  1  |
              | g_idx      |   N   |  -  |
    ----------+------------+-------+-----+
    rowwise   | qweight    |   Y   |  1  |
              | bias       |   0   |  -  |
              | scales     |   Y   |  0  |
              | qzeros     |   Y   |  0  |
              | g_idx      |   N   |  -  |
    """
    param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {}
    for module_name, module_info in module_sharding_info.items():
        gptq_aiu_mod = module_info.linear_module
        params: dict[str, LinearParameterShardingInfo] = {
            "qweight": LinearParameterShardingInfo(
                module_info.sharding_dim, ShardType.SHARD
            ),
            "scales": LinearParameterShardingInfo(
                1 - module_info.sharding_dim, ShardType.SHARD
            ),
            "qzeros": LinearParameterShardingInfo(
                1 - module_info.sharding_dim, ShardType.SHARD
            ),
            # g_idx on AIU is a 1-dim zero tensor, always cloned on each shard
            "g_idx": LinearParameterShardingInfo(0, ShardType.CLONE),
        }
        if gptq_aiu_mod.bias is not None:
            params["bias"] = LinearParameterShardingInfo(
                module_info.sharding_dim,
                ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0,
            )
        param_sharding_info[module_name] = params

    unused_keys = shard_base_linear(
        tensor_values, tp_module, module_sharding_info, param_sharding_info
    )
    return unused_keys


register_linear_type_to_module_map("gptq_aiu", get_gptq_aiu_linear)
register_linear_type_to_sharding_map("gptq_aiu", shard_gptq_aiu_linear)
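The buffer shapes registered above follow standard GPTQ int4-in-int32 packing, with `qweight` stored transposed for the AIU. A quick sketch of the size arithmetic under assumed, purely illustrative layer dimensions:

import math

# illustrative values, not taken from the commit
in_features, out_features, bits, group_size = 4096, 4096, 4, 128

# 32 // bits = 8 int4 values are packed into each int32 element
qweight_shape = (out_features, in_features // 32 * bits)               # (4096, 512), [out_feat, packed in_feat]
qzeros_shape = (math.ceil(in_features / group_size),
                out_features // 32 * bits)                             # (32, 512)
scales_shape = (math.ceil(in_features / group_size), out_features)     # (32, 4096)
g_idx_shape = (1,)                                                     # collapsed to a single zero for AIU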
fms_mo/aiu_addons/gptq/gptq_aiu_op.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Registration of GPTQ W4A16 node compatible with AIU compiler"""

# Standard
import logging

# Third Party
import torch

logger = logging.getLogger(__name__)


def register_aiu_gptq_op():
    """Register AIU-specific op to enable torch compile without graph break.
    The op preserves I/O shapes of a `X @ W^T` matmul but performs no operation.
    Quantization parameters are taken as arguments, so that they end up attached to
    the computational graph.
    """
    if hasattr(torch.ops, "gptq_gemm") and hasattr(
        torch.ops.gptq_gemm, "i4f16_fxinputs_aiu"
    ):
        logger.warning("AIU op has already been registered")
        return

    op_namespace_id = "gptq_gemm::i4f16_fxinputs_aiu"
    torch.library.define(
        op_namespace_id,
        "(Tensor x, Tensor qw, Tensor qzeros, Tensor scales, Tensor g_idx) -> Tensor",
    )

    # Add implementations for the operator
    @torch.library.impl(op_namespace_id, "default")
    def i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx):
        # on AIU, GPTQ qw is [out_feat, in_feat]
        outshape = x.shape[:-1] + (qw.shape[0],)
        x = x.view(-1, x.shape[-1])
        output = torch.zeros(
            (x.shape[0], qw.shape[0]),
            dtype=torch.float16,
            device=x.device,
        )
        return output.view(outshape)

    @torch.library.impl_abstract(op_namespace_id)
    def i4f16_fxinputs_aiu_abstract(x, qw, qzeros, scales, g_idx):
        outshape = x.shape[:-1] + (qw.shape[0],)
        return torch.empty(
            outshape,
            dtype=torch.float16,
            device=x.device,
            requires_grad=False,
        )

    logger.info("GPTQ op 'i4f16_fxinputs_aiu' has been registered")
    return
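Because the eager ("default") implementation above only returns zeros shaped like `X @ W^T`, the registered op can be shape-checked off-device before compilation. A small sketch with illustrative tensor sizes (assuming 4-bit packing, group size 128, in = out = 4096):

import torch
from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op

register_aiu_gptq_op()

# illustrative shapes only, matching the buffer layout of GPTQLinearAIU
x = torch.randn(2, 8, 4096, dtype=torch.float16)      # [batch, seq, in_feat]
qw = torch.zeros(4096, 512, dtype=torch.int32)        # [out_feat, in_feat // 8]
qzeros = torch.zeros(32, 512, dtype=torch.int32)
scales = torch.zeros(32, 4096, dtype=torch.float16)
g_idx = torch.zeros(1, dtype=torch.int32)

out = torch.ops.gptq_gemm.i4f16_fxinputs_aiu(x, qw, qzeros, scales, g_idx)
print(out.shape)  # torch.Size([2, 8, 4096]); placeholder zeros, the real GEMM runs on AIU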

fms_mo/aiu_addons/i8i8/__init__.py

Whitespace-only changes.
