1313from autointent .configs import VocabConfig
1414
1515
16- class BaseTorchModuleWithVocab (nn .Module , ABC ):
class BaseTorchModule(nn.Module, ABC):
    """Abstract base class for torch modules that can be dumped to and loaded from disk.

    Subclasses must implement ``forward``, ``dump`` and ``load``. The ``device``
    property is provided here.
    """

    @abstractmethod
    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Compute sentence embeddings for given text.

        Args:
            text: torch tensor of shape (B, T), token ids

        Returns:
            embeddings of shape (B, H)
        """

    @abstractmethod
    def dump(self, path: Path) -> None:
        """Dump torch module to disk.

        This method encapsulates all the logic of dumping the module's weights and
        the hyperparameters required for initialization from disk and inference.

        Args:
            path: path in file system
        """

    @classmethod
    @abstractmethod
    def load(cls, path: Path, device: str | None = None) -> Self:
        """Load torch module from disk.

        This method loads all weights and hyperparameters required for
        initialization from disk and inference.

        Args:
            path: path in file system
            device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.

        Returns:
            an initialized instance of the module
        """

    @property
    def device(self) -> torch.device:
        """Torch device object where this module resides.

        The device is taken from the first registered parameter. If the module has
        no parameters, the first registered buffer is used instead. If it has
        neither, CPU is reported.
        """
        # ``next`` without a default raises a bare StopIteration on a module with
        # no parameters, which is a confusing error to surface from a property.
        anchor = next(self.parameters(), None)
        if anchor is None:
            anchor = next(self.buffers(), None)
        if anchor is None:
            # A module that holds no tensors is not tied to any accelerator.
            return torch.device("cpu")
        return anchor.device
56+
57+
58+ class BaseTorchModuleWithVocab (BaseTorchModule , ABC ):
1759 def __init__ (
1860 self ,
19- embed_dim : int ,
61+ embed_dim : int | None = None ,
2062 vocab_config : VocabConfig | None = None ,
2163 ) -> None :
2264 super ().__init__ ()
@@ -34,6 +76,9 @@ def __init__(
3476
3577 def set_vocab (self , vocab : dict [str , Any ]) -> None :
3678 """Save vocabulary into module's attributes and initialize embeddings matrix."""
79+ if self .embed_dim is None :
80+ msg = "embed_dim must be set to initialize embeddings"
81+ raise ValueError (msg )
3782 self .vocab_config .vocab = vocab
3883 self .embedding = nn .Embedding (
3984 num_embeddings = len (self .vocab_config .vocab ),
@@ -43,6 +88,10 @@ def set_vocab(self, vocab: dict[str, Any]) -> None:
4388
4489 def build_vocab (self , utterances : list [str ]) -> None :
4590 """Build vocabulary from training utterances."""
91+ if self .embed_dim is None :
92+ msg = "embed_dim must be set to initialize embeddings"
93+ raise ValueError (msg )
94+
4695 if self .vocab_config .vocab is not None :
4796 msg = "Vocab is already built."
4897 raise RuntimeError (msg )
@@ -80,43 +129,3 @@ def text_to_indices(self, utterances: list[str]) -> list[list[int]]:
80129 seq = seq + [self .vocab_config .padding_idx ] * (self .vocab_config .max_seq_length - len (seq ))
81130 sequences .append (seq )
82131 return sequences
83-
84- @abstractmethod
85- def forward (self , text : torch .Tensor ) -> torch .Tensor :
86- """Compute sentence embeddings for given text.
87-
88- Args:
89- text: torch tensor of shape (B, T), token ids
90-
91- Returns:
92- embeddings of shape (B, H)
93- """
94-
95- @abstractmethod
96- def dump (self , path : Path ) -> None :
97- """Dump torch module to disk.
98-
99- This method encapsulates all the logic of dumping module's weights and
100- hyperparameters required for initialization from disk and nice inference.
101-
102- Args:
103- path: path in file system
104- """
105-
106- @classmethod
107- @abstractmethod
108- def load (cls , path : Path , device : str | None = None ) -> Self :
109- """Load torch module from disk.
110-
111- This method loads all weights and hyperparameters required for
112- initialization from disk and inference.
113-
114- Args:
115- path: path in file system
116- device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.
117- """
118-
119- @property
120- def device (self ) -> torch .device :
121- """Torch device object where this module resides."""
122- return next (self .parameters ()).device
0 commit comments