|
15 | 15 |
|
16 | 16 | # Constants |
17 | 17 | APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"] |
18 | | -MAX_LENGTH = 512 |
| 18 | +MAX_LENGTH = 1024 |
19 | 19 |
|
20 | 20 | # Device selection |
21 | 21 | device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu") |
@@ -233,6 +233,18 @@ def inference(model, tokenizer, prompt, effort_levels): |
233 | 233 | return results |
234 | 234 |
|
235 | 235 | def main(args): |
| 236 | + |
| 237 | + if args.push_to_hub: |
| 238 | + base_model = AutoModel.from_pretrained(args.model_name) |
| 239 | + tokenizer = AutoTokenizer.from_pretrained(args.model_name) |
| 240 | + # best_model = OptILMClassifier(base_model, num_labels=len(APPROACHES)) |
| 241 | + # best_model.to(device) |
| 242 | + # load_model(best_model, "best_model.safetensors") |
| 243 | + # We push only the base model and tokenizer here; the trained classifier head (best_model.safetensors) must be uploaded manually, because the OptILMClassifier class doesn't provide a push_to_hub method. |
| 244 | + base_model.push_to_hub(args.hub_model_id) |
| 245 | + tokenizer.push_to_hub(args.hub_model_id) |
| 246 | + return |
| 247 | + |
236 | 248 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) |
237 | 249 | dataset = load_and_preprocess_data(tokenizer) |
238 | 250 |
|
@@ -273,15 +285,6 @@ def main(args): |
273 | 285 |
|
274 | 286 | print(f"\nBest performing model was from fold {best_fold} with validation accuracy {best_val_accuracy:.4f}") |
275 | 287 |
|
276 | | - if args.push_to_hub: |
277 | | - base_model = AutoModel.from_pretrained(args.model_name) |
278 | | - # best_model = OptILMClassifier(base_model, num_labels=len(APPROACHES)) |
279 | | - # best_model.to(device) |
280 | | - # load_model(best_model, "best_model.safetensors") |
281 | | - # we just push the base model and then upload the safetensors file manually as OptILMClassifier class doesn't have a push_to_hub method. |
282 | | - base_model.push_to_hub(args.hub_model_id) |
283 | | - tokenizer.push_to_hub(args.hub_model_id) |
284 | | - |
285 | 288 | # Load the best model for inference |
286 | 289 | base_model = AutoModel.from_pretrained(args.model_name) |
287 | 290 | best_model = OptILMClassifier(base_model, num_labels=len(APPROACHES)) |
|
0 commit comments