Skip to content

Commit 69d5636

Browse files
committed
reformat with new pre-commit after PR#39
1 parent 6de9e86 commit 69d5636

22 files changed

+36
-54
lines changed

chebai/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import os
2-
import torch
32
from typing import Any
43

4+
import torch
5+
56
# Get the absolute path of the current file's directory
67
MODULE_PATH = os.path.abspath(os.path.dirname(__file__))
78

chebai/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import json
22
import os
3+
from typing import Any, Dict, List, Literal, Union
34

45
import torch
5-
from typing import Any, Dict, List, Union, Literal
66
from lightning.pytorch.callbacks import BasePredictionWriter
77

88

chebai/callbacks/model_checkpoint.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
23
from lightning.fabric.utilities.cloud_io import _is_dir
34
from lightning.fabric.utilities.types import _PATH
45
from lightning.pytorch import LightningModule, Trainer

chebai/callbacks/prediction_callback.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from lightning.pytorch import Trainer, LightningModule
2-
from lightning.pytorch.callbacks import BasePredictionWriter
3-
import torch
4-
51
import os
62
import pickle
7-
from typing import Sequence, Any, Literal
3+
from typing import Any, Literal, Sequence
4+
5+
import torch
6+
from lightning.pytorch import LightningModule, Trainer
7+
from lightning.pytorch.callbacks import BasePredictionWriter
88

99

1010
class PredictionWriter(BasePredictionWriter):

chebai/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Dict, Set
2+
23
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
4+
35
from chebai.trainer.CustomTrainer import CustomTrainer
46

57

chebai/loss/bce_weighted.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import os
12
from typing import Optional
23

4+
import pandas as pd
35
import torch
46

57
from chebai.preprocessing.datasets.base import XYBaseDataModule
68
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
7-
import pandas as pd
8-
import os
99

1010

1111
class BCEWeighted(torch.nn.BCEWithLogitsLoss):

chebai/loss/mixed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from torch import nn
21
import torch
2+
from torch import nn
33

44

55
class MixedDataLoss(nn.Module):

chebai/loss/semantic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
import math
33
import os
44
import pickle
5+
from typing import List, Literal, Union
56

67
import torch
78

8-
from typing import Literal, Union, List
9-
109
from chebai.loss.bce_weighted import BCEWeighted
1110
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
1211
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed

chebai/models/base.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
from typing import Optional, Dict, Union, Any
21
import logging
2+
from typing import Any, Dict, Optional, Union
33

44
import torch
5-
6-
from torchmetrics import Metric
7-
85
from lightning.pytorch.core.module import LightningModule
6+
from torchmetrics import Metric
97

108
from chebai.preprocessing.structures import XYData
119

@@ -226,7 +224,8 @@ def _execute(
226224
sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False.
227225
228226
Returns:
229-
Dict[str, Union[torch.Tensor, Any]]: A dictionary containing the processed data, labels, model output, predictions, and loss (if applicable).
227+
Dict[str, Union[torch.Tensor, Any]]: A dictionary containing the processed data, labels, model output,
228+
predictions, and loss (if applicable).
230229
"""
231230
assert isinstance(batch, XYData)
232231
batch = batch.to(self.device)

chebai/models/electra.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
import logging
22
from math import pi
33
from tempfile import TemporaryDirectory
4-
5-
from typing import Any, Dict, Tuple, Optional
6-
7-
from torch import nn, Tensor
4+
from typing import Any, Dict, Optional, Tuple
85

96
import torch
10-
7+
from torch import Tensor, nn
118
from torch.nn.utils.rnn import pad_sequence
129
from transformers import (
1310
ElectraConfig,

0 commit comments

Comments
 (0)