 class FinetunableStaticModel(nn.Module):
-    def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
+    def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0, token_mapping: list[int] | None = None) -> None:
         """
         Initialize a trainable StaticModel from a StaticModel.
@@ -38,14 +38,19 @@ def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int
         )
         self.vectors = vectors.float()

+        if token_mapping is not None:
+            self.token_mapping = torch.tensor(token_mapping, dtype=torch.int64)
+        else:
+            self.token_mapping = torch.arange(len(vectors), dtype=torch.int64)
+        self.token_mapping = nn.Parameter(self.token_mapping, requires_grad=False)
         self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
         self.head = self.construct_head()
         self.w = self.construct_weights()
         self.tokenizer = tokenizer

     def construct_weights(self) -> nn.Parameter:
         """Construct the weights for the model."""
-        weights = torch.zeros(len(self.vectors))
+        weights = torch.zeros(len(self.token_mapping))
         weights[self.pad_id] = -10_000
         return nn.Parameter(weights)
@@ -66,11 +71,16 @@ def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int
         """Load the model from a static model."""
         model.embedding = np.nan_to_num(model.embedding)
         embeddings_converted = torch.from_numpy(model.embedding)
+        if model.token_mapping is not None:
+            token_mapping = [i for _, i in sorted(model.token_mapping.items(), key=lambda x: x[0])]
+        else:
+            token_mapping = None
         return cls(
             vectors=embeddings_converted,
             pad_id=model.tokenizer.token_to_id("[PAD]"),
             out_dim=out_dim,
             tokenizer=model.tokenizer,
+            token_mapping=token_mapping,
             **kwargs,
         )
@@ -90,7 +100,8 @@ def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
         w = w * zeros
         # Add a small epsilon to avoid division by zero
         length = zeros.sum(1) + 1e-16
-        embedded = self.embeddings(input_ids)
+        input_ids_embeddings = self.token_mapping[input_ids]
+        embedded = self.embeddings(input_ids_embeddings)
         # Weigh each token
         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
         # Mean pooling by dividing by the length
@@ -118,16 +129,17 @@ def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tens
         return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)

     @property
-    def device(self) -> str:
+    def device(self) -> torch.device:
         """Get the device of the model."""
         return self.embeddings.weight.device

     def to_static_model(self) -> StaticModel:
         """Convert the model to a static model."""
         emb = self.embeddings.weight.detach().cpu().numpy()
         w = torch.sigmoid(self.w).detach().cpu().numpy()
+        token_mapping = {i: int(token_id) for i, token_id in enumerate(self.token_mapping.tolist())}

-        return StaticModel(emb * w[:, None], self.tokenizer, normalize=True)
+        return StaticModel(vectors=emb, weights=w, tokenizer=self.tokenizer, normalize=True, token_mapping=token_mapping)


 class TextDataset(Dataset):
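A minimal round-trip sketch of the new token_mapping path (the import paths and the checkpoint name below are assumptions for illustration, not part of this commit):

    from model2vec import StaticModel
    from model2vec.train import FinetunableStaticModel  # assumed module path

    # Load a distilled static model (hypothetical checkpoint name)
    base = StaticModel.from_pretrained("minishlab/potion-base-8M")
    # Wrap it in the trainable torch module; from_static_model now carries token_mapping over
    trainable = FinetunableStaticModel.from_static_model(model=base, out_dim=2)
    # ... fine-tune `trainable` as usual ...
    # Converting back preserves the mapping, so pruned/reordered vocabularies round-trip
    distilled = trainable.to_static_model()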