
Commit 952b6d4

Merge pull request #148 from andrea-fasoli/fp8_new_tasks
feat: add QA and MaskedLM task for FP8 encoder instantiation
2 parents: 6c54b37 + 7d44074

2 files changed: +33 −4 lines

fms_mo/run_quant.py

Lines changed: 22 additions & 4 deletions
@@ -35,7 +35,11 @@
 # Third Party
 from datasets import load_from_disk
 from torch.cuda import OutOfMemoryError
-from transformers import AutoTokenizer
+from transformers import (
+    AutoModelForMaskedLM,
+    AutoModelForQuestionAnswering,
+    AutoTokenizer,
+)
 import torch
 import transformers

@@ -204,9 +208,23 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
 
     logger = set_log_level(opt_args.log_level, "fms_mo.run_fp8")
 
-    model = SparseAutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path, torch_dtype=model_args.torch_dtype
-    )
+    if model_args.task_type == "lm":
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            torch_dtype=model_args.torch_dtype,
+        )
+    elif model_args.task_type == "qa":
+        model = AutoModelForQuestionAnswering.from_pretrained(
+            model_args.model_name_or_path,
+            torch_dtype=model_args.torch_dtype,
+        )
+    elif model_args.task_type == "mlm":
+        model = AutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path,
+            torch_dtype=model_args.torch_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported task: {model_args.task_type}")
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
 
     recipe = QuantizationModifier(

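The branch added above is, in effect, a mapping from task_type to a Hugging Face model class. The sketch below is illustrative only and not code from this PR: it expresses the same dispatch as a lookup table. AutoModelForCausalLM stands in for SparseAutoModelForCausalLM (which run_quant.py imports from its compression backend), and load_model_for_task is a hypothetical helper name.

# Illustrative sketch of the task -> model-class dispatch introduced in run_fp8.
# AutoModelForCausalLM is a stand-in for SparseAutoModelForCausalLM used in the PR.
from transformers import (
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForQuestionAnswering,
)

TASK_TO_MODEL_CLASS = {
    "lm": AutoModelForCausalLM,           # causal (decoder) language modeling
    "qa": AutoModelForQuestionAnswering,  # encoder with a span-prediction QA head
    "mlm": AutoModelForMaskedLM,          # encoder with a masked-LM head
}

def load_model_for_task(task_type, model_name_or_path, torch_dtype="bfloat16"):
    """Instantiate the model class matching task_type, mirroring the if/elif chain above."""
    try:
        model_cls = TASK_TO_MODEL_CLASS[task_type]
    except KeyError as exc:
        raise ValueError(f"Unsupported task: {task_type}") from exc
    return model_cls.from_pretrained(model_name_or_path, torch_dtype=torch_dtype)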
fms_mo/training_args.py

Lines changed: 11 additions & 0 deletions
@@ -55,6 +55,17 @@ class ModelArguments(TypeChecker):
     """Dataclass for model related arguments."""
 
     model_name_or_path: str = field(default="facebook/opt-125m")
+    task_type: str = field(
+        default="lm",
+        metadata={
+            "choices": ["lm", "qa", "mlm"],
+            "help": (
+                "Instantiate model for selected task: 'lm' (language modeling), 'qa' "
+                "(question answering, for encoders), 'mlm' (masked language modeling, "
+                "for encoders)."
+            ),
+        },
+    )
     torch_dtype: str = field(default="bfloat16")
     device_map: Optional[str] = field(
         default=None,
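
A hedged usage sketch of how a field like the new task_type could be exercised from the command line, assuming the arguments are parsed with transformers.HfArgumentParser; the dataclass below is a trimmed stand-in for ModelArguments, not the repository's actual class.

# Hedged sketch: parsing a task_type field like the one added above via the CLI.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ModelArgsSketch:  # trimmed stand-in for fms_mo.training_args.ModelArguments
    model_name_or_path: str = field(default="facebook/opt-125m")
    task_type: str = field(
        default="lm",
        metadata={"choices": ["lm", "qa", "mlm"], "help": "Task used to pick the model class."},
    )
    torch_dtype: str = field(default="bfloat16")

if __name__ == "__main__":
    # e.g.: python sketch.py --model_name_or_path bert-base-uncased --task_type qa
    (model_args,) = HfArgumentParser(ModelArgsSketch).parse_args_into_dataclasses()
    print(model_args.task_type, model_args.model_name_or_path)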
