
Commit 4a1e82b

fix mol props as per node data cls

1 parent 84c170b commit 4a1e82b

File tree

3 files changed: 23 additions & 8 deletions

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 7 additions & 6 deletions

@@ -376,14 +376,14 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
             n_atom_node_properties, n_fg_node_properties, n_graph_node_properties
         )
         rank_zero_info(
-            f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n"
-            f"Properties Categories:\n{pformat(props_categories)}"
+            f"\nFinished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n\n"
+            f"Properties Categories:\n{pformat(props_categories)}\n\n"
             f"n_atom_node_properties: {n_atom_node_properties}, "
             f"n_fg_node_properties: {n_fg_node_properties}, "
             f"n_bond_properties: {n_bond_properties}, "
-            f"n_graph_node_properties: {n_graph_node_properties}\n"
+            f"n_graph_node_properties: {n_graph_node_properties}\n\n"
             f"Use following values for given parameters for model configuration: \n\t"
-            f"in_channels: {n_node_properties}, edge_dim: {n_bond_properties}, n_molecule_properties: 0"
+            f"in_channels: {n_node_properties}, edge_dim: {n_bond_properties}, n_molecule_properties: 0\n"
         )

         for property in self.properties:
@@ -449,7 +449,7 @@ def _merge_props_into_base(
         atom_offset, fg_offset, graph_offset = 0, 0, 0

         for property in self.properties:
-            property_values = row[f"{property.name}"]
+            property_values = row[f"{property.name}"].to(dtype=torch.float32)
             if isinstance(property_values, torch.Tensor):
                 if len(property_values.size()) == 0:
                     property_values = property_values.unsqueeze(0)
@@ -482,7 +482,7 @@ def _merge_props_into_base(

             elif isinstance(property, MoleculeProperty):
                 x[is_graph_node, graph_offset : graph_offset + enc_len] = (
-                    property_values[is_graph_node]
+                    property_values
                 )
                 graph_offset += enc_len

@@ -505,6 +505,7 @@ def _merge_props_into_base(
             x=x,
             edge_index=geom_data.edge_index,
             edge_attr=edge_attr,
+            molecule_attr=torch.empty((1, 0)),  # empty as not used for this class
         )
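For context on the `_merge_props_into_base` change: the molecule-level property now arrives as a single row per molecule, so it is written directly into the graph-node row of `x` instead of being re-indexed with `is_graph_node`. Below is a minimal standalone sketch of that offset-based slot filling; all sizes and the tensor layout are illustrative assumptions, not the repository code.

import torch

# Toy setup: 4 atom nodes plus 1 graph-level node, with enc_len columns reserved
# for a molecule-level property (sizes are assumptions for illustration).
n_nodes, base_dim, enc_len = 5, 4, 3
x = torch.zeros(n_nodes, base_dim + enc_len)

is_graph_node = torch.tensor([False, False, False, False, True])
property_values = torch.rand(1, enc_len).to(dtype=torch.float32)  # one row per molecule

graph_offset = base_dim  # columns already occupied by earlier properties
# As in the fixed line: the single molecule row is assigned as-is to the graph-node slot;
# indexing property_values with is_graph_node would assume per-node values it no longer has.
x[is_graph_node, graph_offset : graph_offset + enc_len] = property_values
graph_offset += enc_len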
chebai_graph/preprocessing/properties/properties.py

Lines changed: 6 additions & 0 deletions

@@ -275,6 +275,12 @@ class RDKit2DNormalized(MoleculeProperty):
     def __init__(self, encoder: PropertyEncoder | None = None) -> None:
         super().__init__(encoder or AsIsEncoder(self))
         self.generator_normalized = rdNormalizedDescriptors.RDKit2DNormalized()
+        # Create a dummy molecule (e.g., methane) to extract the length of descriptor vector
+        dummy_mol = Chem.MolFromSmiles("C")
+        descr_values = self.generator_normalized.processMol(
+            dummy_mol, Chem.MolToSmiles(dummy_mol)
+        )
+        self.encoder.set_encoding_length(len(descr_values) - 1)

     def get_property_value(self, mol: Chem.rdchem.Mol) -> list[np.ndarray]:
         """
chebai_graph/preprocessing/property_encoder.py

Lines changed: 10 additions & 2 deletions

@@ -243,8 +243,16 @@ def encode(self, token: float | int | None) -> torch.Tensor:
             Tensor of shape (1,) containing the input value or zero.
         """
         if token is None:
-            return torch.tensor([0])
-        return torch.tensor([token])
+            return torch.zeros(1, self.get_encoding_length())
+        assert (
+            len(token) == self.get_encoding_length()
+        ), "Length of token should be equal to encoding length"
+        # return torch.tensor([token])  # token is an ndarray, no need to create list of ndarray due to below warning
+        # UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow.
+        # Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor.
+        # (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\utils\tensor_new.cpp:257.)
+        # ----- fix for above warning
+        return torch.tensor(token).unsqueeze(0)  # shape: (1, len(token))


class BoolEncoder(PropertyEncoder):
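The `encode` rewrite sidesteps PyTorch's slow path for lists of NumPy arrays: wrapping the descriptor array in a Python list goes through a generic element-wise conversion and can trigger the UserWarning quoted in the diff, while converting the array directly and unsqueezing gives the same `(1, len(token))` shape without it. A small comparison, with the array size chosen arbitrarily for illustration:

import numpy as np
import torch

token = np.random.rand(6).astype(np.float32)  # stand-in for one descriptor vector

# Old style: a list containing an ndarray can trigger the
# "Creating a tensor from a list of numpy.ndarrays is extremely slow" UserWarning.
slow = torch.tensor([token])

# New style from the commit: convert the array directly, then add the leading dimension.
fast = torch.tensor(token).unsqueeze(0)

assert slow.shape == fast.shape == (1, 6)
assert torch.allclose(slow, fast)

The `None` branch now also returns `torch.zeros(1, self.get_encoding_length())`, so missing values keep the same width as encoded ones.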
