Skip to content

Commit 61252a6

Browse files
committed
Add support in addons for INT8 Granite arch
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent dea5d56 commit 61252a6

File tree

2 files changed

+20
-43
lines changed

2 files changed

+20
-43
lines changed

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -97,41 +97,16 @@ def _add_defaults_and_concat(
9797
)
9898

9999

100-
# registration of new adapter steps for each architecture
101-
serialization.register_adapter_step("llama", "int8_qparams_aiu", _int8_qparams_aiu)
102-
serialization.register_adapter_step(
103-
"gpt_bigcode", "int8_qparams_aiu", _int8_qparams_aiu
104-
)
105-
serialization.register_adapter_step("roberta", "int8_qparams_aiu", _int8_qparams_aiu)
106-
serialization.register_adapter_step(
107-
"roberta_question_answering",
108-
"int8_qparams_aiu",
109-
_int8_qparams_aiu,
110-
)
111-
112-
# registration of multi-step adapter for each architecture
113-
serialization.register_adapter(
114-
"llama",
115-
"fms_mo",
116-
[
117-
"hf_to_fms_names",
118-
"hf_to_fms_rope",
119-
"weight_fusion",
120-
"int8_qparams_aiu",
121-
],
122-
)
123-
serialization.register_adapter(
124-
"gpt_bigcode", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
125-
)
126-
serialization.register_adapter(
127-
"roberta", "fms_mo", ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
128-
)
129-
serialization.register_adapter(
130-
"roberta_question_answering",
131-
"fms_mo",
132-
[
133-
"hf_to_fms_names",
134-
"weight_fusion",
135-
"int8_qparams_aiu",
136-
],
137-
)
100+
# registration of new adapter step and adapter for each architecture
101+
for arch in ["llama", "gpt_bigcode", "granite", "roberta", "roberta_question_answering"]:
102+
serialization.register_adapter_step(arch, "int8_qparams_aiu", _int8_qparams_aiu)
103+
if arch in ["llama", "granite"]:
104+
steps_to_register = [
105+
"hf_to_fms_names",
106+
"hf_to_fms_rope",
107+
"weight_fusion",
108+
"int8_qparams_aiu",
109+
]
110+
else:
111+
steps_to_register = ["hf_to_fms_names", "weight_fusion", "int8_qparams_aiu"]
112+
serialization.register_adapter(arch, "fms_mo", steps_to_register)

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ def __init__(
8484
"weight",
8585
torch.zeros(out_features, in_features, dtype=torch.int8),
8686
)
87-
if bias:
88-
self.register_buffer(
89-
"bias", torch.zeros((out_features), dtype=torch.float16)
90-
)
87+
88+
self.has_bias = bias
89+
bias_size = out_features if self.has_bias else 1
90+
self.register_buffer(
91+
"bias", torch.zeros((bias_size), dtype=torch.float16)
92+
)
9193

9294
if config.weight_per_channel:
9395
w_clip_size = out_features
@@ -192,7 +194,7 @@ def __repr__(self) -> str:
192194
return (
193195
f"{self.__class__.__name__}"
194196
f"(in={self.in_features}, out={self.out_features}, "
195-
f"bias={self.bias is not None}, wq={self.weight_quant_type}, "
197+
f"bias={self.has_bias}, wq={self.weight_quant_type}, "
196198
f"aq={self.activ_quant_type}, smoothq={self.smoothquant}, "
197199
f"op={self.aiu_op})"
198200
)

0 commit comments

Comments
 (0)