Commit 7d44074

Add QA and MaskedLM task encoder architectures in run_fp8
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent bcee5f3 commit 7d44074

2 files changed: +33 −4 lines changed

fms_mo/run_quant.py

Lines changed: 22 additions & 4 deletions
@@ -36,7 +36,11 @@
 from datasets import load_from_disk
 from huggingface_hub.errors import HFValidationError
 from torch.cuda import OutOfMemoryError
-from transformers import AutoTokenizer
+from transformers import (
+    AutoModelForMaskedLM,
+    AutoModelForQuestionAnswering,
+    AutoTokenizer,
+)
 import torch
 import transformers

@@ -204,9 +208,23 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):

     logger = set_log_level(opt_args.log_level, "fms_mo.run_fp8")

-    model = SparseAutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path, torch_dtype=model_args.torch_dtype
-    )
+    if model_args.task_type == "lm":
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            torch_dtype=model_args.torch_dtype,
+        )
+    elif model_args.task_type == "qa":
+        model = AutoModelForQuestionAnswering.from_pretrained(
+            model_args.model_name_or_path,
+            torch_dtype=model_args.torch_dtype,
+        )
+    elif model_args.task_type == "mlm":
+        model = AutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path,
+            torch_dtype=model_args.torch_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported task: {model_args.task_type}")
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

     recipe = QuantizationModifier(
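
For reference, a minimal sketch of the dispatch this hunk introduces, not code from the repository: task_type selects which transformers Auto class instantiates the model, written here as a dict lookup instead of the if/elif chain. The helper name load_encoder, the dict-based dispatch, and the bfloat16 default are illustrative assumptions.

# Illustrative sketch only -- equivalent in spirit to the if/elif chain above.
# `load_encoder`, the dict dispatch, and the bfloat16 default are assumptions
# for this example, not part of the commit.
import torch
from transformers import (
    AutoModelForMaskedLM,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
)

TASK_TO_MODEL_CLASS = {
    "qa": AutoModelForQuestionAnswering,  # encoder with extractive-QA head
    "mlm": AutoModelForMaskedLM,          # encoder with masked-LM head
}

def load_encoder(model_name_or_path, task_type, torch_dtype=torch.bfloat16):
    """Instantiate an encoder model and tokenizer for the requested task."""
    try:
        model_cls = TASK_TO_MODEL_CLASS[task_type]
    except KeyError as exc:
        raise ValueError(f"Unsupported task: {task_type}") from exc
    model = model_cls.from_pretrained(model_name_or_path, torch_dtype=torch_dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    return model, tokenizer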

fms_mo/training_args.py

Lines changed: 11 additions & 0 deletions
@@ -55,6 +55,17 @@ class ModelArguments(TypeChecker):
     """Dataclass for model related arguments."""

     model_name_or_path: str = field(default="facebook/opt-125m")
+    task_type: str = field(
+        default="lm",
+        metadata={
+            "choices": ["lm", "qa", "mlm"],
+            "help": (
+                "Instantiate model for selected task: 'lm' (language modeling), 'qa' "
+                "(question answering, for encoders), 'mlm' (masked language modeling, "
+                "for encoders)."
+            ),
+        },
+    )
     torch_dtype: str = field(default="bfloat16")
     device_map: Optional[str] = field(
         default=None,
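
A hedged usage sketch of the new argument: assuming ModelArguments is parsed with transformers.HfArgumentParser (a common pattern for such dataclasses, not shown in this diff), the field surfaces as a --task_type flag restricted to lm/qa/mlm. The trimmed dataclass below is a stand-in, not the repository's full definition.

# Assumption: ModelArguments is fed to HfArgumentParser; the trimmed dataclass
# below is a stand-in for fms_mo.training_args.ModelArguments.
from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="facebook/opt-125m")
    task_type: str = field(
        default="lm",
        metadata={"choices": ["lm", "qa", "mlm"], "help": "Model head to instantiate."},
    )

parser = HfArgumentParser(ModelArguments)
# e.g. an encoder loaded for masked language modeling:
(model_args,) = parser.parse_args_into_dataclasses(
    ["--model_name_or_path", "bert-base-uncased", "--task_type", "mlm"]
)
print(model_args.task_type)  # -> "mlm"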

0 commit comments
