
Commit bed7ebe

explicit float32 tensors for property encoders

1 parent: 84c170b

File tree: 2 files changed (+9, -9 lines)


chebai_graph/preprocessing/datasets/chebi.py
Lines changed: 1 addition & 1 deletion

@@ -443,7 +443,7 @@ def _merge_props_into_base(
        assert (
            max_len_node_properties is not None
        ), "Maximum len of node properties should not be None"
-       x = torch.zeros((num_nodes, max_len_node_properties))
+       x = torch.zeros((num_nodes, max_len_node_properties), dtype=torch.float32)

        # Track column offsets for each node type
        atom_offset, fg_offset, graph_offset = 0, 0, 0
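
Note (background, not part of the commit): torch.zeros without an explicit dtype falls back to torch.get_default_dtype(), which is float32 by default but can be changed globally, so pinning dtype=torch.float32 makes the feature matrix independent of that global setting. A minimal sketch:

    import torch

    x = torch.zeros((2, 3))
    print(x.dtype)  # torch.float32 (the library default)

    torch.set_default_dtype(torch.float64)  # some other code flips the default
    print(torch.zeros((2, 3)).dtype)        # torch.float64, silently different

    # An explicit dtype is immune to the global default:
    print(torch.zeros((2, 3), dtype=torch.float32).dtype)  # torch.float32

    torch.set_default_dtype(torch.float32)  # restore the default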

chebai_graph/preprocessing/property_encoder.py

Lines changed: 8 additions & 8 deletions
@@ -54,7 +54,7 @@ def on_finish(self) -> None:
        return


-class IndexEncoder(PropertyEncoder):
+class IndexEncoder(PropertyEncoder, abc.ABC):
    """
    Encodes property values as indices. For that purpose, compiles a dynamic list of different values that have
    occurred. Stores this list in a file for later reference.
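
Aside (an assumption about intent, not stated in the commit): adding abc.ABC to the bases marks IndexEncoder as abstract. Python only blocks instantiation while unimplemented @abstractmethod members remain, so concrete subclasses such as OneHotEncoder are unaffected. Illustrated with hypothetical classes:

    import abc

    class Base(abc.ABC):
        @abc.abstractmethod
        def encode(self, token):  # hypothetical abstract method
            ...

    class Concrete(Base):
        def encode(self, token):
            return token

    Concrete()  # fine: all abstract methods are implemented
    try:
        Base()
    except TypeError as err:
        print(err)  # can't instantiate abstract class Base ...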
@@ -148,11 +148,11 @@ def encode(self, token: str | None) -> torch.Tensor:
        """
        if token is None:
            self._count_for_unk_token += 1
-           return torch.tensor([self._unk_token_idx])
+           return torch.tensor([self._unk_token_idx], dtype=torch.float32)

        if str(token) not in self.cache:
            self.cache[str(token)] = len(self.cache)
-       return torch.tensor([self.cache[str(token)] + self.offset])
+       return torch.tensor([self.cache[str(token)] + self.offset], dtype=torch.float32)


class OneHotEncoder(IndexEncoder):
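
For reference (not part of the commit): torch.tensor infers its dtype from the Python values it is given, so an integer index wrapped without an explicit dtype comes out as int64. The explicit dtype=torch.float32 keeps the encoder outputs uniform:

    import torch

    print(torch.tensor([5]).dtype)                       # torch.int64 (inferred from int)
    print(torch.tensor([5], dtype=torch.float32).dtype)  # torch.float32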
@@ -215,11 +215,11 @@ def encode(self, token: str | None) -> torch.Tensor:
        """
        if token not in self.tokens_dict:
            self._count_for_unk_token += 1
-           return torch.zeros(1, self.get_encoding_length(), dtype=torch.int64)
+           return torch.zeros(1, self.get_encoding_length(), dtype=torch.float32)

        return torch.nn.functional.one_hot(
            self.tokens_dict[token], num_classes=self.get_encoding_length()
-       )
+       ).to(dtype=torch.float32)


class AsIsEncoder(PropertyEncoder):
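
Worth noting (background, not from the commit): torch.nn.functional.one_hot always returns an int64 tensor, so the trailing .to(dtype=torch.float32) is what brings this branch in line with the rest of the encoders. A quick check:

    import torch
    import torch.nn.functional as F

    oh = F.one_hot(torch.tensor(2), num_classes=5)
    print(oh.dtype)                          # torch.int64, one_hot's fixed return type
    print(oh.to(dtype=torch.float32).dtype)  # torch.float32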
@@ -243,8 +243,8 @@ def encode(self, token: float | int | None) -> torch.Tensor:
            Tensor of shape (1,) containing the input value or zero.
        """
        if token is None:
-           return torch.tensor([0])
-       return torch.tensor([token])
+           return torch.tensor([0], dtype=torch.float32)
+       return torch.tensor([token], dtype=torch.float32)
@@ -267,4 +267,4 @@ def encode(self, token: bool) -> torch.Tensor:
        Returns:
            Tensor with 1 if True else 0.
        """
-       return torch.tensor([1 if token else 0])
+       return torch.tensor([1 if token else 0], dtype=torch.float32)
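
Taken together (a hypothetical usage sketch, not code from the repository): every encoder now returns float32, so per-property encodings can be combined without dtype promotion surprises:

    import torch

    # Outputs mirroring the patched encoders, all float32:
    index_enc = torch.tensor([7], dtype=torch.float32)    # IndexEncoder-style output
    bool_enc = torch.tensor([1], dtype=torch.float32)     # BoolEncoder-style output
    as_is_enc = torch.tensor([0.5], dtype=torch.float32)  # AsIsEncoder-style output

    features = torch.cat([index_enc, bool_enc, as_is_enc])
    print(features.dtype)  # torch.float32, matching the float32 matrix in chebi.py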
