Skip to content

Commit 50b6ce3

Browse files
Merge pull request #92 from andrea-fasoli/int8_granite_addon
feat: int8 granite addon
2 parents 848b766 + 876e1c5 commit 50b6ce3

File tree

2 files changed

+22
-41
lines changed

2 files changed

+22
-41
lines changed

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -97,41 +97,22 @@ def _add_defaults_and_concat(
9797
)
9898

9999

100-
# registration of new adapter steps for each architecture
101-
serialization.register_adapter_step("llama", "int8_qparams_aiu", _int8_qparams_aiu)
102-
serialization.register_adapter_step(
103-
"gpt_bigcode", "int8_qparams_aiu", _int8_qparams_aiu
104-
)
105-
serialization.register_adapter_step("roberta", "int8_qparams_aiu", _int8_qparams_aiu)
106-
serialization.register_adapter_step(
107-
"roberta_question_answering",
108-
"int8_qparams_aiu",
109-
_int8_qparams_aiu,
110-
)
111-
112-
# registration of multi-step adapter for each architecture
113-
serialization.register_adapter(
100+
# registration of new adapter step and adapter for each architecture
101+
for arch in [
114102
"llama",
115-
"fms_mo",
116-
[
117-
"hf_to_fms_names",
118-
"hf_to_fms_rope",
119-
"weight_fusion",
120-
"int8_qparams_aiu",
121-
],
122-
)
123-
serialization.register_adapter(
124-
"gpt_bigcode", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
125-
)
126-
serialization.register_adapter(
127-
"roberta", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
128-
)
129-
serialization.register_adapter(
103+
"gpt_bigcode",
104+
"granite",
105+
"roberta",
130106
"roberta_question_answering",
131-
"fms_mo",
132-
[
133-
"hf_to_fms_names",
134-
"weight_fusion",
135-
"int8_qparams_aiu",
136-
],
137-
)
107+
]:
108+
serialization.register_adapter_step(arch, "int8_qparams_aiu", _int8_qparams_aiu)
109+
if arch in ["llama", "granite"]:
110+
steps_to_register = [
111+
"hf_to_fms_names",
112+
"hf_to_fms_rope",
113+
"weight_fusion",
114+
"int8_qparams_aiu",
115+
]
116+
else:
117+
steps_to_register = ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
118+
serialization.register_adapter(arch, "fms_mo", steps_to_register)

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,10 @@ def __init__(
8484
"weight",
8585
torch.zeros(out_features, in_features, dtype=torch.int8),
8686
)
87-
if bias:
88-
self.register_buffer(
89-
"bias", torch.zeros((out_features), dtype=torch.float16)
90-
)
87+
88+
self.has_bias = bias
89+
bias_size = out_features if self.has_bias else 1
90+
self.register_buffer("bias", torch.zeros((bias_size), dtype=torch.float16))
9191

9292
if config.weight_per_channel:
9393
w_clip_size = out_features
@@ -192,7 +192,7 @@ def __repr__(self) -> str:
192192
return (
193193
f"{self.__class__.__name__}"
194194
f"(in={self.in_features}, out={self.out_features}, "
195-
f"bias={self.bias is not None}, wq={self.weight_quant_type}, "
195+
f"bias={self.has_bias}, wq={self.weight_quant_type}, "
196196
f"aq={self.activ_quant_type}, smoothq={self.smoothquant}, "
197197
f"op={self.aiu_op})"
198198
)

0 commit comments

Comments (0)