@@ -54,7 +54,7 @@ def on_finish(self) -> None:
5454 return
5555
5656
57- class IndexEncoder (PropertyEncoder ):
57+ class IndexEncoder (PropertyEncoder , abc . ABC ):
5858 """
5959 Encodes property values as indices. For that purpose, compiles a dynamic list of different values that have
6060 occurred. Stores this list in a file for later reference.
@@ -148,11 +148,11 @@ def encode(self, token: str | None) -> torch.Tensor:
148148 """
149149 if token is None :
150150 self ._count_for_unk_token += 1
151- return torch .tensor ([self ._unk_token_idx ])
151+ return torch .tensor ([self ._unk_token_idx ], dtype = torch . float32 )
152152
153153 if str (token ) not in self .cache :
154154 self .cache [str (token )] = len (self .cache )
155- return torch .tensor ([self .cache [str (token )] + self .offset ])
155+ return torch .tensor ([self .cache [str (token )] + self .offset ], dtype = torch . float32 )
156156
157157
158158class OneHotEncoder (IndexEncoder ):
@@ -215,11 +215,11 @@ def encode(self, token: str | None) -> torch.Tensor:
215215 """
216216 if token not in self .tokens_dict :
217217 self ._count_for_unk_token += 1
218- return torch .zeros (1 , self .get_encoding_length (), dtype = torch .int64 )
218+ return torch .zeros (1 , self .get_encoding_length (), dtype = torch .float32 )
219219
220220 return torch .nn .functional .one_hot (
221221 self .tokens_dict [token ], num_classes = self .get_encoding_length ()
222- )
222+ ). to ( dtype = torch . float32 )
223223
224224
225225class AsIsEncoder (PropertyEncoder ):
@@ -243,8 +243,8 @@ def encode(self, token: float | int | None) -> torch.Tensor:
243243 Tensor of shape (1,) containing the input value or zero.
244244 """
245245 if token is None :
246- return torch .tensor ([0 ])
247- return torch .tensor ([token ])
246+ return torch .tensor ([0 ], dtype = torch . float32 )
247+ return torch .tensor ([token ], dtype = torch . float32 )
248248
249249
250250class BoolEncoder (PropertyEncoder ):
@@ -267,4 +267,4 @@ def encode(self, token: bool) -> torch.Tensor:
267267 Returns:
268268 Tensor with 1 if True else 0.
269269 """
270- return torch .tensor ([1 if token else 0 ])
270+ return torch .tensor ([1 if token else 0 ], dtype = torch . float32 )
0 commit comments