@@ -36,7 +36,11 @@
 from datasets import load_from_disk
 from huggingface_hub.errors import HFValidationError
 from torch.cuda import OutOfMemoryError
-from transformers import AutoTokenizer
+from transformers import (
+    AutoModelForMaskedLM,
+    AutoModelForQuestionAnswering,
+    AutoTokenizer,
+)
 import torch
 import transformers
 
@@ -204,9 +208,23 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
 
     logger = set_log_level(opt_args.log_level, "fms_mo.run_fp8")
 
-    model = SparseAutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path, torch_dtype=model_args.torch_dtype
-    )
+    if model_args.task_type == "lm":
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            torch_dtype=model_args.torch_dtype,
+        )
+    elif model_args.task_type == "qa":
+        model = AutoModelForQuestionAnswering.from_pretrained(
+            model_args.model_name_or_path,
+            torch_dtype=model_args.torch_dtype,
+        )
+    elif model_args.task_type == "mlm":
+        model = AutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path,
+            torch_dtype=model_args.torch_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported task: {model_args.task_type}")
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
 
     recipe = QuantizationModifier(
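For reference, a minimal standalone sketch of the task-type dispatch this patch introduces, written against plain transformers Auto classes. The dict-based lookup, the load_model_for_task helper, and the example checkpoint name are illustrative assumptions, not part of the fms_mo script (which uses SparseAutoModelForCausalLM for the "lm" branch, as shown in the diff):

# Sketch only: mirrors the if/elif branching above with a table-driven lookup.
# "task_type" values and the example checkpoint are assumptions for illustration.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
)

_MODEL_CLASSES = {
    "lm": AutoModelForCausalLM,            # causal language modeling
    "qa": AutoModelForQuestionAnswering,   # extractive question answering
    "mlm": AutoModelForMaskedLM,           # masked language modeling
}

def load_model_for_task(task_type, model_name_or_path, torch_dtype=torch.bfloat16):
    """Return (model, tokenizer) for the given task type."""
    try:
        model_cls = _MODEL_CLASSES[task_type]
    except KeyError as exc:
        raise ValueError(f"Unsupported task: {task_type}") from exc
    model = model_cls.from_pretrained(model_name_or_path, torch_dtype=torch_dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    return model, tokenizer

# Example usage (hypothetical checkpoint):
# model, tokenizer = load_model_for_task("qa", "bert-base-uncased")

A lookup table like this keeps the supported task types in one place; the if/elif form used in the patch is equally valid and avoids the extra indirection.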