Skip to content

Commit c3bbdfa

Browse files
committed
fix imports
1 parent 4f1f995 commit c3bbdfa

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

chebai/result/analyse_sem.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import gc
2+
import os
23
import traceback
34
from datetime import datetime
4-
from typing import List, LiteralString
5+
from typing import List, LiteralString, Optional, Tuple
56

67
import pandas as pd
8+
import torch
9+
import wandb
710
from torchmetrics.functional.classification import (
811
multilabel_auroc,
912
multilabel_average_precision,
@@ -14,8 +17,11 @@
1417
from chebai.models import Electra
1518
from chebai.preprocessing.datasets.base import _DynamicDataset
1619
from chebai.preprocessing.datasets.chebi import ChEBIOver100
17-
from chebai.preprocessing.datasets.pubchem import PubChemKMeans
18-
from chebai.result.utils import *
20+
from chebai.result.utils import (
21+
evaluate_model,
22+
get_checkpoint_from_wandb,
23+
load_results_from_buffer,
24+
)
1925

2026
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2127

@@ -739,6 +745,8 @@ def run_fuzzy_loss(tag="fuzzy_loss", skip_first_n=0):
739745

740746

741747
if __name__ == "__main__":
748+
import sys
749+
742750
if len(sys.argv) > 2:
743751
run_fuzzy_loss(sys.argv[1], int(sys.argv[2]))
744752
elif len(sys.argv) > 1:

0 commit comments

Comments
 (0)