44import torch
55import torch .nn as nn
66from pydantic import BaseModel
7+ from typing_extensions import Self
78
89from autointent ._utils import detect_device
10+ from autointent ._wrappers import BaseTorchModuleWithVocab
911
1012
1113class GCNModelDumpMetadata (BaseModel ):
@@ -28,7 +30,7 @@ def forward(self, adj_matrix, features):
2830 return output
2931
3032
31- class TextMLGCN (nn . Module ):
33+ class TextMLGCN (BaseTorchModuleWithVocab ):
3234 _metadata_dict_name = "metadata.json"
3335 _state_dict_name = "state_dict.pt"
3436
@@ -41,7 +43,7 @@ def __init__(
4143 p_reweight : float ,
4244 tau_threshold : float ,
4345 ):
44- super ().__init__ ()
46+ super ().__init__ (embed_dim = bert_feature_dim )
4547 self .num_classes = num_classes
4648 self .p_reweight = p_reweight
4749 self .tau_threshold = tau_threshold
@@ -93,7 +95,7 @@ def set_correlation_matrix(self, train_labels):
9395 )
9496 self .correlation_matrix .data .copy_ (corr_matrix )
9597
96- def forward (self , bert_features , label_embeddings ):
98+ def forward (self , bert_features , label_embeddings ): # type: ignore
9799 classifiers = label_embeddings
98100 for i in range (len (self .gcn_layers )):
99101 classifiers = self .gcn_layers [i ](self .correlation_matrix , classifiers )
@@ -102,10 +104,6 @@ def forward(self, bert_features, label_embeddings):
102104 logits = torch .matmul (bert_features , classifiers .T )
103105 return logits
104106
105- @property
106- def device (self ) -> torch .device :
107- return next (self .parameters ()).device
108-
109107 def dump (self , path : Path ) -> None :
110108 metadata = GCNModelDumpMetadata (
111109 num_classes = self .num_classes ,
@@ -124,7 +122,7 @@ def dump(self, path: Path) -> None:
124122 self .to (device )
125123
126124 @classmethod
127- def load (cls , path : Path , device : str | None = None ) -> "TextMLGCN" :
125+ def load (cls , path : Path , device : str | None = None ) -> Self :
128126 with (path / cls ._metadata_dict_name ).open () as file :
129127 metadata = GCNModelDumpMetadata (** json .load (file ))
130128 device = device or detect_device ()
0 commit comments