Commit a31a4e5

Merge pull request #53 from foundation-model-stack/adapter_fixes
feat: added granite support; fixed adapters to ignore model_config
2 parents: 3f9c14e + b033c90

8 files changed: +67 additions, -26 deletions

fms_mo/aiu_addons/__init__.py

Whitespace-only changes.

fms_mo/aiu_addons/gptq/gptq_aiu_adapter.py

Lines changed: 15 additions & 0 deletions
@@ -23,6 +23,7 @@
 
 def _gptq_qweights_transpose_aiu(
     input_sd: Mapping[str, torch.Tensor],
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Mapping[str, torch.Tensor]:
     new_sd = {}
     for name, param in input_sd.items():
@@ -41,6 +42,9 @@ def _gptq_qweights_transpose_aiu(
 serialization.register_adapter_step(
     "gpt_bigcode", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu
 )
+serialization.register_adapter_step(
+    "granite", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu
+)
 serialization.register_adapter(
     "llama",
     "hf_gptq_aiu",
@@ -57,3 +61,14 @@ def _gptq_qweights_transpose_aiu(
     "hf_gptq_aiu",
     ["hf_to_fms_names", "weight_fusion", "gptq_qweights_transpose_aiu"],
 )
+serialization.register_adapter(
+    "granite",
+    "hf_gptq_aiu",
+    [
+        "hf_to_fms_names",
+        "hf_to_fms_rope",
+        "hf_gptq_fusion_check",
+        "weight_fusion",
+        "gptq_qweights_transpose_aiu",
+    ],
+)
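
The new **kwargs parameter is what lets these adapter steps ignore the model_config argument that FMS serialization may now pass to every step (per the PR title). A minimal illustrative sketch of the pattern, using a hypothetical step and call site rather than the repository's actual invocation:

from typing import Mapping

import torch


def _noop_adapter_step(
    input_sd: Mapping[str, torch.Tensor],
    **kwargs,  # extra arguments (e.g. model_config) are accepted but ignored
) -> Mapping[str, torch.Tensor]:
    """Hypothetical adapter step: returns the state dict unchanged."""
    return dict(input_sd)


# The caller can now forward extra context without breaking the step:
sd = {"w": torch.zeros(2, 2)}
out = _noop_adapter_step(sd, model_config={"architecture": "granite"})
assert out["w"].shape == (2, 2)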

fms_mo/aiu_addons/gptq/gptq_aiu_linear.py

Lines changed: 10 additions & 3 deletions
@@ -28,15 +28,18 @@
 from fms.modules.tp import ShardType, TPModule
 from fms.utils.gptq import GPTQLinearConfig
 import torch
-import torch.nn as nn
 
 # Local
 from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op
 
 register_aiu_gptq_op()
 
 
-class GPTQLinearAIU(nn.Module):
+class GPTQLinearAIU(torch.nn.Module):
+    """Simplified QLinear that wraps GPTQ W4A16 custom operation.
+    gptq_gemm.i4f16_fxinputs_aiu must have been pre-registered to use this class.
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -112,6 +115,8 @@ def __init__(
         self.aiu_op = torch.ops.gptq_gemm.i4f16_fxinputs_aiu
 
     def forward(self, x):
+        """Call pre-registered custom GPTQ operation"""
+
         x = self.aiu_op(
             x.half(),
             self.qweight,
@@ -137,7 +142,9 @@ def get_gptq_aiu_linear(
     out_features: int,
     bias: bool,
     linear_config: Optional[Mapping[str, Any]] = None,
-):
+) -> torch.nn.Module:
+    """Retrieve a GPTQ W4A16 Linear module"""
+
     gptq_config = GPTQLinearConfig(**linear_config)
     if gptq_config.desc_act:
         raise NotImplementedError(
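
GPTQLinearAIU resolves its kernel through torch.ops.gptq_gemm.i4f16_fxinputs_aiu, which register_aiu_gptq_op() (defined in gptq_aiu_op.py below) must register beforehand. As a rough sketch of that general pattern, under an invented namespace and op and with torch.library details that vary across PyTorch versions, a custom op becomes reachable through torch.ops like this:

import torch

# Hypothetical namespace/op, not the repository's gptq_gemm.i4f16_fxinputs_aiu.
_lib = torch.library.Library("demo_gemm", "DEF")
_lib.define("identity_mm(Tensor x, Tensor w) -> Tensor")


def _identity_mm_cpu(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in implementation: a plain matmul against the transposed weight.
    return x @ w.t()


_lib.impl("identity_mm", _identity_mm_cpu, "CPU")

# Once registered, the op is looked up the same way GPTQLinearAIU finds its kernel:
op = torch.ops.demo_gemm.identity_mm
print(op(torch.randn(2, 4), torch.randn(3, 4)).shape)  # torch.Size([2, 3])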

fms_mo/aiu_addons/gptq/gptq_aiu_op.py

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@
 # Third Party
 import torch
 
+# pylint: disable=unused-argument
+# gptq op must be registered with specific I/O, even if not in use by the op function
+
 logger = logging.getLogger(__name__)
 
 
fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 1 addition & 1 deletion
@@ -23,6 +23,7 @@
 
 def _int8_qparams_aiu(
     input_sd: Mapping[str, torch.Tensor],
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Mapping[str, torch.Tensor]:
     new_sd = {}
     modules_seen = set()
@@ -94,7 +95,6 @@ def _add_defaults_and_concat(
             sq_scale.to(torch.float32),
         )
     )
-    return
 
 
 # registration of new adapter steps for each architecture

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 11 additions & 8 deletions
@@ -28,7 +28,6 @@
 from fms.modules.tp import ShardType, TPModule
 from fms.utils.config import ModelConfig
 import torch
-import torch.nn as nn
 
 # Local
 from fms_mo.aiu_addons.i8i8.i8i8_aiu_op import register_aiu_i8i8_op
@@ -38,6 +37,8 @@
 
 @dataclass
 class W8A8LinearConfig(ModelConfig):
+    """Configuration for W8A8 Linear module"""
+
     linear_type: str = "int8"
     bits: int = 8
     weight_per_channel: bool = False
@@ -46,8 +47,10 @@ class W8A8LinearConfig(ModelConfig):
     smoothquant_layers: Optional[list] = None
 
 
-class W8A8LinearAIU(nn.Module):
-    """Simplified QLinear that wraps quantize/dequantize operation"""
+class W8A8LinearAIU(torch.nn.Module):
+    """Simplified QLinear that wraps quantize/dequantize operation.
+    fms_mo.i8i8_aiu must have been pre-registered to use this class.
+    """
 
     def __init__(
         self,
@@ -199,7 +202,9 @@ def get_int8_aiu_linear(
     bias: bool,
     linear_config: Optional[Mapping[str, Any]] = None,
     use_smoothquant: bool = True,
-):
+) -> torch.nn.Module:
+    """Retrieve a W8A8 Linear module"""
+
     int8_config = W8A8LinearConfig(**linear_config)
     linear = W8A8LinearAIU(
         in_features=in_features,
@@ -216,8 +221,7 @@ def shard_int8_aiu_linear(
     tp_module: TPModule,
     module_sharding_info: dict[str, LinearModuleShardingInfo],
 ) -> Optional[set]:
-    """
-    Set up INT8 (W8A8) quantization parameters to be sharded onto
+    """Set up INT8 (W8A8) quantization parameters to be sharded onto
     AIU-compliant linear modules
 
                         | GPU |
@@ -273,8 +277,7 @@ def shard_int8_aiu_linear(
     )
 
     raise NotImplementedError("TP not yet supported for INT8. Work in progress")
-
-    return unused_keys
+    # return unused_keys
 
 
 register_linear_type_to_module_map("int8_aiu", get_int8_aiu_linear)
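
For reference, get_int8_aiu_linear simply unpacks the caller's linear_config mapping into W8A8LinearConfig. A stand-in sketch restricted to the fields visible in this diff (the real class subclasses fms.utils.config.ModelConfig and declares additional fields not shown here):

from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoW8A8LinearConfig:
    """Stand-in for W8A8LinearConfig, limited to fields shown in the diff."""

    linear_type: str = "int8"
    bits: int = 8
    weight_per_channel: bool = False
    smoothquant_layers: Optional[list] = None


# get_int8_aiu_linear builds its config by unpacking the user mapping:
linear_config = {"linear_type": "int8_aiu", "weight_per_channel": True}
int8_config = DemoW8A8LinearConfig(**linear_config)
print(int8_config.bits, int8_config.weight_per_channel)  # 8 True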

fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py

Lines changed: 26 additions & 14 deletions
@@ -22,6 +22,13 @@
 
 logger = logging.getLogger(__name__)
 
+# pylint: disable=unused-argument
+# i8i8 op must be registered with specific I/O, even if not in use by the op function
+
+# pylint: disable=not-callable
+# torch.nn.functional.linear not recognized as callable
+# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
+
 
 def register_aiu_i8i8_op():
     """Register AIU-specific op to enable torch compile without graph break.
@@ -64,7 +71,8 @@ def i8i8_aiu(
     dtype = x.dtype
     out_feat, in_feat = weight.size()
 
-    w_cv, w_cvn, a_cv, a_cvn, zshift, sq = extract_qdata(
+    # unused returns are w_cvn and zero_shift
+    w_cv, _, a_cv, a_cvn, _, sq = extract_qdata(
         qdata,
         weight_quant_type,
         activ_quant_type,
@@ -88,6 +96,8 @@ def i8i8_aiu_abstract(
     activ_quant_type,
     smoothquant,
 ):
+    """OP template of I/O sizes"""
+
     outshape = x.size()[:-1] + (weight.size(0),)
     return torch.empty(
         outshape, dtype=x.dtype, device=x.device, requires_grad=False
@@ -153,18 +163,19 @@ def dequant_weights(
     w_cv: torch.Tensor,
     sq: torch.Tensor,
     weight_quant_type: str,
-):
+) -> torch.Tensor:
+    """Dequantize integer weights based on quantizer type"""
+
     if weight_quant_type == "per_tensor":  # assume 8-bit symmetric W quantization
         # w size: (out_feat, in_feat)
        # sq size: (in_feat) or (1), no need to unsqueeze
         return (weight * w_cv / 127) / sq
-    elif weight_quant_type == "per_channel":
+    if weight_quant_type == "per_channel":
         # w_cv is (out_feat), need to unsqueeze to broadcast mul to weight
         return (weight * w_cv.unsqueeze(dim=1) / 127) / sq
-    else:
-        raise NotImplementedError(
-            f"weight quantizantion type {weight_quant_type} is not supported"
-        )
+    raise NotImplementedError(
+        f"weight quantizantion type {weight_quant_type} is not supported"
+    )
@@ -173,8 +184,10 @@ def quant_dequant_activ(
     a_cvn: torch.Tensor,
     sq: torch.Tensor,
     activ_quant_type: str,
-):
+) -> torch.Tensor:
     """
+    Quantize and dequantize activations based on quantizer type
+
     x size (*, hid_dim)
     sq size (hid_dim) or (1)
     => no need to unsqueeze to perform x / sq
@@ -183,18 +196,17 @@
         scale_x = 127 / a_cv
         x_int = torch.round(x / sq * scale_x).clamp(-127, 127)
         return x_int / scale_x * sq
-    elif activ_quant_type == "per_tensor_asymm":
+    if activ_quant_type == "per_tensor_asymm":
         scale_x = 255 / (a_cv - a_cvn)
         zp_x = a_cvn * scale_x
         x_int = torch.round(x / sq * scale_x - zp_x).clamp(0, 255)
         return (x_int + zp_x) / scale_x * sq
-    elif activ_quant_type == "per_token":
+    if activ_quant_type == "per_token":
         x_sq = x / sq
         a_cv_per_token = x_sq.abs().max(dim=-1, keepdim=True)[0]
         scale_x = 127 / a_cv_per_token
         x_int = torch.round(x_sq * scale_x).clamp(-127, 127)
         return x_int / scale_x * sq
-    else:
-        raise NotImplementedError(
-            f"activation quantizantion type {activ_quant_type} is not supported"
-        )
+    raise NotImplementedError(
+        f"activation quantizantion type {activ_quant_type} is not supported"
+    )
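
The branches being refactored above all follow a scale, round, clamp, rescale pattern. A self-contained sketch of the per_tensor weight dequantization and the per_tensor_symm activation fake-quantization, on toy tensors rather than through the registered fms_mo.i8i8_aiu op:

import torch

# Toy inputs: activations x, already-quantized int8 weights, and calibration values.
x = torch.randn(4, 16)                               # activations, (*, hid_dim)
weight = torch.randint(-127, 128, (8, 16)).float()   # int8 weight values as float
w_cv = torch.tensor(3.0)                             # weight clip value (per tensor)
a_cv = torch.tensor(5.0)                             # activation clip value (per tensor)
sq = torch.ones(16)                                  # smoothquant scale (identity here)

# per_tensor weight dequantization, as in dequant_weights:
w_deq = (weight * w_cv / 127) / sq

# per_tensor_symm activation quant-dequant, as in quant_dequant_activ:
scale_x = 127 / a_cv
x_int = torch.round(x / sq * scale_x).clamp(-127, 127)
x_qdq = x_int / scale_x * sq

# The simulated int8 matmul then reduces to an ordinary linear over the
# fake-quantized tensors (the pylint note above references torch.nn.functional.linear).
out = torch.nn.functional.linear(x_qdq, w_deq)
print(out.shape)  # torch.Size([4, 8])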

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ dependencies = [
     "huggingface_hub",
     "pandas",
     "safetensors",
+    "ibm-fms>=0.0.8"
 ]
 
 [project.optional-dependencies]
