Commit 8562e93

Feat/gcn scorer (#261)
* main code for gcn
* multiclass fix
* base torch
* base changes after review
* fix tests
* two errors remaining
* ruff fix
* test fix
* mypy1
* mypy2
* mypy3
* mypy4
* mypy5
* mypy6
* mypy7
* mypy8
* mypy9
* vscode
* main review-3
* mypy fix
* comments
* replace
* fix lint mypy
* mypy fix
* return_tensors
* mypy fix
* overload fix
1 parent ed82eae commit 8562e93

18 files changed, +629 −144 lines changed

.vscode/settings.json

Lines changed: 6 additions & 1 deletion
@@ -8,5 +8,10 @@
       "*.yaml",
       "!*/.github/*/*.yaml"
     ]
-  }
+  },
+  "python.testing.pytestArgs": [
+    "."
+  ],
+  "python.testing.unittestEnabled": false,
+  "python.testing.pytestEnabled": true
 }

src/autointent/_dump_tools/main.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,7 @@
 
 import numpy as np
 import numpy.typing as npt
+import torch
 
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
 from autointent.context.optimization_info import Artifact
@@ -108,6 +109,8 @@ def dump(
             simple_attrs[key] = val
         elif isinstance(val, np.ndarray):
             arrays[key] = val
+        elif isinstance(val, torch.Tensor):
+            arrays[key] = val.cpu().numpy()
         else:
             # Use the appropriate dumper for complex objects
             Dumper._dump_single_object(key, val, path, exists_ok, raise_errors)
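
The new branch mirrors the existing np.ndarray case: a tensor is moved to CPU and converted before joining the arrays dict, since a CUDA-resident tensor cannot be converted with .numpy() directly. A minimal round-trip sketch; the np.savez persistence step is an assumption for illustration, not shown in this diff:

import numpy as np
import torch

# Moving to CPU first is what the new Dumper branch does; in real use
# `val` may live on "cuda" or "mps".
val = torch.randn(4, 8)
arrays = {"weights": val.cpu().numpy()}

# Assuming the arrays dict is later persisted with something like np.savez,
# the tensor round-trips losslessly through its numpy representation.
np.savez("arrays.npz", **arrays)
restored = torch.from_numpy(np.load("arrays.npz")["weights"])
assert torch.equal(val, restored)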

src/autointent/_dump_tools/unit_dumpers.py

Lines changed: 6 additions & 6 deletions
@@ -21,7 +21,7 @@
 )
 
 from autointent import Embedder, Ranker, VectorIndex
-from autointent._wrappers import BaseTorchModuleWithVocab
+from autointent._wrappers import BaseTorchModule
 from autointent.schemas import TagsList
 
 from .base import BaseObjectDumper, ModuleSimpleAttributes
@@ -276,11 +276,11 @@ def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
         return isinstance(obj, PreTrainedTokenizer | PreTrainedTokenizerFast)
 
 
-class TorchModelDumper(BaseObjectDumper[BaseTorchModuleWithVocab]):
+class TorchModelDumper(BaseObjectDumper[BaseTorchModule]):
     dir_or_file_name = "torch_models"
 
     @staticmethod
-    def dump(obj: BaseTorchModuleWithVocab, path: Path, exists_ok: bool) -> None:
+    def dump(obj: BaseTorchModule, path: Path, exists_ok: bool) -> None:
         path.mkdir(parents=True, exist_ok=exists_ok)
         class_info = {
             "module": obj.__class__.__module__,
@@ -291,16 +291,16 @@ def dump(obj: BaseTorchModuleWithVocab, path: Path, exists_ok: bool) -> None:
         obj.dump(path)
 
     @staticmethod
-    def load(path: Path, **kwargs: Any) -> BaseTorchModuleWithVocab:  # noqa: ANN401, ARG004
+    def load(path: Path, **kwargs: Any) -> BaseTorchModule:  # noqa: ANN401, ARG004
         with (path / "class_info.json").open("r") as f:
             class_info = json.load(f)
         module = importlib.import_module(class_info["module"])
-        model_class: BaseTorchModuleWithVocab = getattr(module, class_info["name"])
+        model_class: BaseTorchModule = getattr(module, class_info["name"])
         return model_class.load(path)
 
     @classmethod
     def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
-        return isinstance(obj, BaseTorchModuleWithVocab)
+        return isinstance(obj, BaseTorchModule)
 
 
 class CatBoostDumper(BaseObjectDumper[CatBoostClassifier]):
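
TorchModelDumper stays generic by persisting only the class's import coordinates (module path and class name) and delegating weight I/O to the module itself. The load path is plain importlib reflection; a self-contained sketch of the same pattern, with stand-in class_info values since the real dict is written by dump():

import importlib

# dump() records where the concrete class can be re-imported from.
class_info = {"module": "collections", "name": "OrderedDict"}  # stand-in values

module = importlib.import_module(class_info["module"])
model_class = getattr(module, class_info["name"])
obj = model_class()  # TorchModelDumper calls model_class.load(path) here instead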

src/autointent/_wrappers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,5 +2,6 @@
 from .embedder import Embedder
 from .vector_index import VectorIndex
 from .base_torch_module import BaseTorchModuleWithVocab
+from .base_torch_module import BaseTorchModule
 
-__all__ = ["BaseTorchModuleWithVocab", "Embedder", "Ranker", "VectorIndex"]
+__all__ = ["BaseTorchModule", "BaseTorchModuleWithVocab", "Embedder", "Ranker", "VectorIndex"]

src/autointent/_wrappers/base_torch_module.py

Lines changed: 51 additions & 42 deletions
@@ -13,10 +13,52 @@
 from autointent.configs import VocabConfig
 
 
-class BaseTorchModuleWithVocab(nn.Module, ABC):
+class BaseTorchModule(nn.Module, ABC):
+    @abstractmethod
+    def forward(self, text: torch.Tensor) -> torch.Tensor:
+        """Compute sentence embeddings for given text.
+
+        Args:
+            text: torch tensor of shape (B, T), token ids
+
+        Returns:
+            embeddings of shape (B, H)
+        """
+
+    @abstractmethod
+    def dump(self, path: Path) -> None:
+        """Dump torch module to disk.
+
+        This method encapsulates all the logic of dumping module's weights and
+        hyperparameters required for initialization from disk and nice inference.
+
+        Args:
+            path: path in file system
+        """
+
+    @classmethod
+    @abstractmethod
+    def load(cls, path: Path, device: str | None = None) -> Self:
+        """Load torch module from disk.
+
+        This method loads all weights and hyperparameters required for
+        initialization from disk and inference.
+
+        Args:
+            path: path in file system
+            device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.
+        """
+
+    @property
+    def device(self) -> torch.device:
+        """Torch device object where this module resides."""
+        return next(self.parameters()).device
+
+
+class BaseTorchModuleWithVocab(BaseTorchModule, ABC):
     def __init__(
         self,
-        embed_dim: int,
+        embed_dim: int | None = None,
         vocab_config: VocabConfig | None = None,
     ) -> None:
         super().__init__()
@@ -34,6 +76,9 @@ def __init__(
 
     def set_vocab(self, vocab: dict[str, Any]) -> None:
         """Save vocabulary into module's attributes and initialize embeddings matrix."""
+        if self.embed_dim is None:
+            msg = "embed_dim must be set to initialize embeddings"
+            raise ValueError(msg)
         self.vocab_config.vocab = vocab
         self.embedding = nn.Embedding(
             num_embeddings=len(self.vocab_config.vocab),
@@ -43,6 +88,10 @@ def set_vocab(self, vocab: dict[str, Any]) -> None:
 
     def build_vocab(self, utterances: list[str]) -> None:
         """Build vocabulary from training utterances."""
+        if self.embed_dim is None:
+            msg = "embed_dim must be set to initialize embeddings"
+            raise ValueError(msg)
+
         if self.vocab_config.vocab is not None:
             msg = "Vocab is already built."
             raise RuntimeError(msg)
@@ -80,43 +129,3 @@ def text_to_indices(self, utterances: list[str]) -> list[list[int]]:
             seq = seq + [self.vocab_config.padding_idx] * (self.vocab_config.max_seq_length - len(seq))
             sequences.append(seq)
         return sequences
-
-    @abstractmethod
-    def forward(self, text: torch.Tensor) -> torch.Tensor:
-        """Compute sentence embeddings for given text.
-
-        Args:
-            text: torch tensor of shape (B, T), token ids
-
-        Returns:
-            embeddings of shape (B, H)
-        """
-
-    @abstractmethod
-    def dump(self, path: Path) -> None:
-        """Dump torch module to disk.
-
-        This method encapsulates all the logic of dumping module's weights and
-        hyperparameters required for initialization from disk and nice inference.
-
-        Args:
-            path: path in file system
-        """
-
-    @classmethod
-    @abstractmethod
-    def load(cls, path: Path, device: str | None = None) -> Self:
-        """Load torch module from disk.
-
-        This method loads all weights and hyperparameters required for
-        initialization from disk and inference.
-
-        Args:
-            path: path in file system
-            device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.
-        """
-
-    @property
-    def device(self) -> torch.device:
-        """Torch device object where this module resides."""
-        return next(self.parameters()).device
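
Hoisting forward, dump, load, and the device property into BaseTorchModule lets models that need no token vocabulary (the GCN case) subclass it directly, while BaseTorchModuleWithVocab keeps the vocabulary plumbing. A hedged sketch of a minimal conforming subclass; the class name and its pooling strategy are illustrative only, not code from this PR:

from pathlib import Path

import torch
from torch import nn

from autointent._wrappers import BaseTorchModule


class MeanPoolEncoder(BaseTorchModule):
    """Toy encoder: embeds token ids and mean-pools them into (B, H)."""

    def __init__(self, vocab_size: int = 128, embed_dim: int = 16) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        # (B, T) token ids -> (B, H) sentence embeddings
        return self.embedding(text).mean(dim=1)

    def dump(self, path: Path) -> None:
        path.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), path / "weights.pt")

    @classmethod
    def load(cls, path: Path, device: str | None = None) -> "MeanPoolEncoder":
        model = cls()
        model.load_state_dict(torch.load(path / "weights.pt", map_location=device))
        return model

A subclass like this is exactly what TorchModelDumper above can persist and restore without knowing the concrete type.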

src/autointent/_wrappers/embedder.py

Lines changed: 44 additions & 11 deletions
@@ -10,6 +10,7 @@
 import tempfile
 from functools import lru_cache
 from pathlib import Path
+from typing import Literal, cast, overload
 from uuid import uuid4
 
 import huggingface_hub
@@ -235,15 +236,28 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
 
         return cls(EmbedderConfig(**kwargs))
 
-    def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
+    @overload
+    def embed(
+        self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[True]
+    ) -> torch.Tensor: ...
+
+    @overload
+    def embed(
+        self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[False] = False
+    ) -> npt.NDArray[np.float32]: ...
+
+    def embed(
+        self, utterances: list[str], task_type: TaskTypeEnum | None = None, return_tensors: bool = False
+    ) -> npt.NDArray[np.float32] | torch.Tensor:
         """Calculate embeddings for a list of utterances.
 
         Args:
             utterances: List of input texts to calculate embeddings for.
             task_type: Type of task for which embeddings are calculated.
+            return_tensors: If True, return a PyTorch tensor; otherwise, return a numpy array.
 
         Returns:
-            A numpy array of embeddings.
+            A numpy array or PyTorch tensor of embeddings.
         """
         if len(utterances) == 0:
             msg = "Empty input"
@@ -263,7 +277,10 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         embeddings_path = _get_embeddings_path(hasher.hexdigest())
         if embeddings_path.exists():
             logger.debug("loading embeddings from %s", str(embeddings_path))
-            return np.load(embeddings_path)  # type: ignore[no-any-return]
+            embeddings_np = cast(npt.NDArray[np.float32], np.load(embeddings_path))
+            if return_tensors:
+                return torch.from_numpy(embeddings_np).to(self.config.device)
+            return embeddings_np
 
         self._model = self._load_model()
 
@@ -279,17 +296,33 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         if self.config.tokenizer_config.max_length is not None:
             self._model.max_seq_length = self.config.tokenizer_config.max_length
 
-        embeddings = self._model.encode(
-            utterances,
-            convert_to_numpy=True,
-            batch_size=self.config.batch_size,
-            normalize_embeddings=True,
-            prompt=prompt,
-        )
+        embeddings: npt.NDArray[np.float32] | torch.Tensor
+        if return_tensors:
+            embeddings = self._model.encode(
+                utterances,
+                convert_to_tensor=True,
+                batch_size=self.config.batch_size,
+                normalize_embeddings=True,
+                prompt=prompt,
+            )
+        else:
+            embeddings = cast(
+                npt.NDArray[np.float32],
+                self._model.encode(
+                    utterances,
+                    convert_to_numpy=True,
+                    batch_size=self.config.batch_size,
+                    normalize_embeddings=True,
+                    prompt=prompt,
+                ),
+            )
 
         if self.config.use_cache:
             embeddings_path.parent.mkdir(parents=True, exist_ok=True)
-            np.save(embeddings_path, embeddings)
+            if isinstance(embeddings, torch.Tensor):
+                np.save(embeddings_path, embeddings.cpu().numpy())
+            else:
+                np.save(embeddings_path, embeddings)
 
         return embeddings
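
The paired @overload declarations exist so that static type checkers can narrow embed's return type from the literal value of return_tensors, while a single runtime implementation serves both paths. A self-contained sketch of that typing pattern, using a stand-in function rather than the Embedder itself:

from typing import Literal, overload

import numpy as np
import numpy.typing as npt
import torch


@overload
def encode(texts: list[str], *, return_tensors: Literal[True]) -> torch.Tensor: ...
@overload
def encode(texts: list[str], *, return_tensors: Literal[False] = False) -> npt.NDArray[np.float32]: ...
def encode(texts: list[str], *, return_tensors: bool = False) -> npt.NDArray[np.float32] | torch.Tensor:
    # Stand-in body: random unit-norm rows instead of a real sentence encoder.
    arr = np.random.rand(len(texts), 4).astype(np.float32)
    arr /= np.linalg.norm(arr, axis=1, keepdims=True)
    return torch.from_numpy(arr) if return_tensors else arr


emb_np = encode(["hi"])                       # checker infers npt.NDArray[np.float32]
emb_pt = encode(["hi"], return_tensors=True)  # checker infers torch.Tensor

Note the cache behavior in the diff: even with return_tensors=True, embeddings are stored on disk as numpy and converted back to a tensor on self.config.device when a cached file is reloaded.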

src/autointent/modules/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
     CNNScorer,
     CrossEncoderDescriptionScorer,
     DNNCScorer,
+    GCNScorer,
     KNNScorer,
     LinearScorer,
     LLMDescriptionScorer,
@@ -47,6 +48,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
     [
         CatBoostScorer,
         DNNCScorer,
+        GCNScorer,
         KNNScorer,
         LinearScorer,
         BiEncoderDescriptionScorer,

src/autointent/modules/scoring/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 from ._catboost import CatBoostScorer
 from ._description import BiEncoderDescriptionScorer, CrossEncoderDescriptionScorer, LLMDescriptionScorer
 from ._dnnc import DNNCScorer
+from ._gcn import GCNScorer
 from ._knn import KNNScorer, RerankScorer
 from ._linear import LinearScorer
 from ._lora import BERTLoRAScorer
@@ -18,6 +19,7 @@
     "CatBoostScorer",
     "CrossEncoderDescriptionScorer",
     "DNNCScorer",
+    "GCNScorer",
     "KNNScorer",
     "LLMDescriptionScorer",
     "LinearScorer",
src/autointent/modules/scoring/_gcn/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from .gcn_scorer import GCNScorer
+
+__all__ = ["GCNScorer"]
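
Taken together with the two __init__.py diffs above, the new package re-exports GCNScorer at each level of the hierarchy. A usage note, assuming the import block at the top of modules/__init__.py is a from .scoring import (…) statement, as the diff context suggests:

# Equivalent imports after this commit (the innermost path is a private package):
from autointent.modules import GCNScorer
from autointent.modules.scoring import GCNScorer
from autointent.modules.scoring._gcn import GCNScorer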
