Skip to content

Commit ea339b4

Browse files
committed
added granite support; fixed adapters to ignore model_config
Signed-off-by: JOSHUA ROSENKRANZ <[email protected]>
1 parent 4335a9d commit ea339b4

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

fms_mo/aiu_addons/__init__.py

Whitespace-only changes.

fms_mo/aiu_addons/gptq/gptq_aiu_adapter.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
def _gptq_qweights_transpose_aiu(
25-
input_sd: Mapping[str, torch.Tensor],
25+
input_sd: Mapping[str, torch.Tensor], **kwargs
2626
) -> Mapping[str, torch.Tensor]:
2727
new_sd = {}
2828
for name, param in input_sd.items():
@@ -41,6 +41,7 @@ def _gptq_qweights_transpose_aiu(
4141
serialization.register_adapter_step(
4242
"gpt_bigcode", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu
4343
)
44+
serialization.register_adapter_step("granite", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu)
4445
serialization.register_adapter(
4546
"llama",
4647
"hf_gptq_aiu",
@@ -57,3 +58,9 @@ def _gptq_qweights_transpose_aiu(
5758
"hf_gptq_aiu",
5859
["hf_to_fms_names", "weight_fusion", "gptq_qweights_transpose_aiu"],
5960
)
61+
serialization.register_adapter(
62+
"granite",
63+
"hf_gptq_aiu",
64+
["hf_to_fms_names", "hf_to_fms_rope", "hf_gptq_fusion_check", "weight_fusion", "gptq_qweights_transpose_aiu"]
65+
)
66+

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
def _int8_qparams_aiu(
25-
input_sd: Mapping[str, torch.Tensor],
25+
input_sd: Mapping[str, torch.Tensor], **kwargs
2626
) -> Mapping[str, torch.Tensor]:
2727
new_sd = {}
2828
modules_seen = set()

0 commit comments

Comments (0)