Skip to content

Commit 252f645

Browse files
committed
base torch
1 parent 8d8e103 commit 252f645

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

autointent/_dump_tools/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
import numpy.typing as npt
7+
import torch
78

89
from autointent.configs import CrossEncoderConfig, EmbedderConfig
910
from autointent.context.optimization_info import Artifact
@@ -108,6 +109,8 @@ def dump(
108109
simple_attrs[key] = val
109110
elif isinstance(val, np.ndarray):
110111
arrays[key] = val
112+
elif isinstance(val, torch.Tensor):
113+
arrays[key] = val.cpu().numpy()
111114
else:
112115
# Use the appropriate dumper for complex objects
113116
Dumper._dump_single_object(key, val, path, exists_ok, raise_errors)

autointent/modules/scoring/_gcn/gcn_model.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import torch
55
import torch.nn as nn
66
from pydantic import BaseModel
7+
from typing_extensions import Self
78

89
from autointent._utils import detect_device
10+
from autointent._wrappers import BaseTorchModuleWithVocab
911

1012

1113
class GCNModelDumpMetadata(BaseModel):
@@ -28,7 +30,7 @@ def forward(self, adj_matrix, features):
2830
return output
2931

3032

31-
class TextMLGCN(nn.Module):
33+
class TextMLGCN(BaseTorchModuleWithVocab):
3234
_metadata_dict_name = "metadata.json"
3335
_state_dict_name = "state_dict.pt"
3436

@@ -41,7 +43,7 @@ def __init__(
4143
p_reweight: float,
4244
tau_threshold: float,
4345
):
44-
super().__init__()
46+
super().__init__(embed_dim=bert_feature_dim)
4547
self.num_classes = num_classes
4648
self.p_reweight = p_reweight
4749
self.tau_threshold = tau_threshold
@@ -93,7 +95,7 @@ def set_correlation_matrix(self, train_labels):
9395
)
9496
self.correlation_matrix.data.copy_(corr_matrix)
9597

96-
def forward(self, bert_features, label_embeddings):
98+
def forward(self, bert_features, label_embeddings): # type: ignore
9799
classifiers = label_embeddings
98100
for i in range(len(self.gcn_layers)):
99101
classifiers = self.gcn_layers[i](self.correlation_matrix, classifiers)
@@ -102,10 +104,6 @@ def forward(self, bert_features, label_embeddings):
102104
logits = torch.matmul(bert_features, classifiers.T)
103105
return logits
104106

105-
@property
106-
def device(self) -> torch.device:
107-
return next(self.parameters()).device
108-
109107
def dump(self, path: Path) -> None:
110108
metadata = GCNModelDumpMetadata(
111109
num_classes=self.num_classes,
@@ -124,7 +122,7 @@ def dump(self, path: Path) -> None:
124122
self.to(device)
125123

126124
@classmethod
127-
def load(cls, path: Path, device: str | None = None) -> "TextMLGCN":
125+
def load(cls, path: Path, device: str | None = None) -> Self:
128126
with (path / cls._metadata_dict_name).open() as file:
129127
metadata = GCNModelDumpMetadata(**json.load(file))
130128
device = device or detect_device()

autointent/modules/scoring/_gcn/gcn_scorer.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from pydantic import PositiveInt
77
from torch import nn
88
from torch.utils.data import DataLoader, TensorDataset
9+
from typing_extensions import Self
910

1011
from autointent import Context, Embedder
11-
from autointent.configs import EmbedderConfig, TaskTypeEnum, TorchTrainingConfig
12+
from autointent.configs import CrossEncoderConfig, EmbedderConfig, TaskTypeEnum, TorchTrainingConfig
1213
from autointent.custom_types import ListOfLabels
1314
from autointent.modules.base import BaseScorer
1415
from autointent.modules.scoring._gcn.gcn_model import TextMLGCN
@@ -162,3 +163,15 @@ def clear_cache(self) -> None:
162163
if hasattr(self, "_label_embedder"):
163164
self._label_embedder.clear_ram()
164165
del self._label_embedder
166+
167+
@classmethod
168+
def load(
169+
cls,
170+
path: str,
171+
embedder_config: EmbedderConfig | None = None,
172+
cross_encoder_config: CrossEncoderConfig | None = None,
173+
) -> Self:
174+
instance = super().load(path, embedder_config, cross_encoder_config)
175+
if hasattr(instance, "_label_embeddings"):
176+
instance._label_embeddings = torch.tensor(instance._label_embeddings).to(instance.torch_config.device)
177+
return instance

0 commit comments

Comments (0)