File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change 11import gc
2+ import os
23import traceback
34from datetime import datetime
4- from typing import List , LiteralString
5+ from typing import List , LiteralString , Optional , Tuple
56
67import pandas as pd
8+ import torch
9+ import wandb
710from torchmetrics .functional .classification import (
811 multilabel_auroc ,
912 multilabel_average_precision ,
1417from chebai .models import Electra
1518from chebai .preprocessing .datasets .base import _DynamicDataset
1619from chebai .preprocessing .datasets .chebi import ChEBIOver100
17- from chebai .preprocessing .datasets .pubchem import PubChemKMeans
18- from chebai .result .utils import *
20+ from chebai .result .utils import (
21+ evaluate_model ,
22+ get_checkpoint_from_wandb ,
23+ load_results_from_buffer ,
24+ )
1925
2026DEVICE = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
2127
@@ -739,6 +745,8 @@ def run_fuzzy_loss(tag="fuzzy_loss", skip_first_n=0):
739745
740746
741747if __name__ == "__main__" :
748+ import sys
749+
742750 if len (sys .argv ) > 2 :
743751 run_fuzzy_loss (sys .argv [1 ], int (sys .argv [2 ]))
744752 elif len (sys .argv ) > 1 :
You can’t perform that action at this time.
0 commit comments