Skip to content

Commit 9c6d915

Browse files
committed
fix merge conflict
2 parents a7915ec + 1ba24e0 commit 9c6d915

File tree

4 files changed

+75
-106
lines changed

4 files changed

+75
-106
lines changed

chebai_graph/models/base.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,14 @@ class GraphNetWrapper(GraphBaseNet, ABC):
7474
"""
7575

7676
def __init__(
77-
self, config: dict, n_linear_layers: int, n_molecule_properties: Optional[int] = 0, use_batch_norm: bool = False, **kwargs
78-
) -> None:
77+
self,
78+
config: dict,
79+
n_linear_layers: int,
80+
n_molecule_properties: Optional[int] = 0,
81+
use_batch_norm: bool = False,
82+
**kwargs,
83+
):
7984
"""
80-
Initialize the GNN and linear layers.
81-
8285
Args:
8386
config (dict): Model configuration.
8487
n_linear_layers (int): Number of linear layers.
@@ -91,7 +94,9 @@ def __init__(
9194
self.activation = torch.nn.ELU
9295
self.lin_input_dim = self._get_lin_seq_input_dim(
9396
gnn_out_dim=gnn_out_dim,
94-
n_molecule_properties=n_molecule_properties if n_molecule_properties is not None else 0,
97+
n_molecule_properties=(
98+
n_molecule_properties if n_molecule_properties is not None else 0
99+
),
95100
)
96101
self.use_batch_norm = use_batch_norm
97102
if self.use_batch_norm:

chebai_graph/models/dynamic_gni.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,43 @@ def __init__(self, config: dict[str, Any], **kwargs: Any):
3333
)
3434
self.distribution = distribution
3535

36-
self.complete_randomness = config.get("complete_randomness", True)
36+
self.complete_randomness = (
37+
str(config.get("complete_randomness", "True")).lower() == "true"
38+
)
39+
40+
print("Using complete randomness: ", self.complete_randomness)
3741

3842
if not self.complete_randomness:
3943
assert (
40-
"random_pad_node" in config or "random_pad_edge" in config
41-
), "Missing 'random_pad_node' or 'random_pad_edge' in config when complete_randomness is False"
42-
self.random_pad_node = (
43-
int(config["random_pad_node"])
44-
if config.get("random_pad_node") is not None
44+
"pad_node_features" in config or "pad_edge_features" in config
45+
), "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False"
46+
self.pad_node_features = (
47+
int(config["pad_node_features"])
48+
if config.get("pad_node_features") is not None
4549
else None
4650
)
47-
self.random_pad_edge = (
48-
int(config["random_pad_edge"])
49-
if config.get("random_pad_edge") is not None
51+
if self.pad_node_features is not None:
52+
print(
53+
f"[Info] Node features will be padded with {self.pad_node_features} "
54+
f"new set of random features from distribution {self.distribution} "
55+
f"in each forward pass."
56+
)
57+
58+
self.pad_edge_features = (
59+
int(config["pad_edge_features"])
60+
if config.get("pad_edge_features") is not None
5061
else None
5162
)
63+
if self.pad_edge_features is not None:
64+
print(
65+
f"[Info] Edge features will be padded with {self.pad_edge_features} "
66+
f"new set of random features from distribution {self.distribution} "
67+
f"in each forward pass."
68+
)
69+
5270
assert (
53-
self.random_pad_node > 0 or self.random_pad_edge > 0
54-
), "'random_pad_node' or 'random_pad_edge' must be positive integers"
71+
self.pad_node_features > 0 or self.pad_edge_features > 0
72+
), "'pad_node_features' or 'pad_edge_features' must be positive integers"
5573

5674
self.resgated: BasicGNN = ResGatedModel(
5775
in_channels=self.in_channels,
@@ -93,21 +111,21 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
93111
new_edge_attr, self.distribution
94112
)
95113
else:
96-
if self.random_pad_node is not None:
114+
if self.pad_node_features is not None:
97115
pad_node = torch.empty(
98116
graph_data.x.shape[0],
99-
self.random_pad_node,
117+
self.pad_node_features,
100118
device=self.device,
101119
)
102120
RandomFeatureInitializationReader.random_gni(
103121
pad_node, self.distribution
104122
)
105123
new_x = torch.cat((graph_data.x, pad_node), dim=1)
106124

107-
if self.random_pad_edge is not None:
125+
if self.pad_edge_features is not None:
108126
pad_edge = torch.empty(
109127
graph_data.edge_attr.shape[0],
110-
self.random_pad_edge,
128+
self.pad_edge_features,
111129
device=self.device,
112130
)
113131
RandomFeatureInitializationReader.random_gni(

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 29 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -189,45 +189,29 @@ def __init__(
189189
self,
190190
properties=None,
191191
transform=None,
192-
zero_pad_node: int = None,
193-
zero_pad_edge: int = None,
194-
random_pad_node: int = None,
195-
random_pad_edge: int = None,
192+
pad_node_features: int = None,
193+
pad_edge_features: int = None,
196194
distribution: str = "normal",
197195
**kwargs,
198196
):
199197
super().__init__(properties, transform, **kwargs)
200-
self.zero_pad_node = int(zero_pad_node) if zero_pad_node else None
201-
if self.zero_pad_node:
202-
print(
203-
f"[Info] Node-level features will be zero-padded with "
204-
f"{self.zero_pad_node} additional dimensions."
205-
)
206-
207-
self.zero_pad_edge = int(zero_pad_edge) if zero_pad_edge else None
208-
if self.zero_pad_edge:
209-
print(
210-
f"[Info] Edge-level features will be zero-padded with "
211-
f"{self.zero_pad_edge} additional dimensions."
212-
)
213-
214-
self.random_pad_edge = int(random_pad_edge) if random_pad_edge else None
215-
self.random_pad_node = int(random_pad_node) if random_pad_node else None
216-
if self.random_pad_node or self.random_pad_edge:
198+
self.pad_edge_features = int(pad_edge_features) if pad_edge_features else None
199+
self.pad_node_features = int(pad_node_features) if pad_node_features else None
200+
if self.pad_node_features or self.pad_edge_features:
217201
assert (
218202
distribution is not None
219203
and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS
220-
), "When using random padding, a valid distribution must be specified."
204+
), "When using padding for features, a valid distribution must be specified."
221205
self.distribution = distribution
222-
if self.random_pad_node:
206+
if self.pad_node_features:
223207
print(
224-
f"[Info] Node-level features will be padded with "
225-
f"{self.random_pad_node} additional dimensions initialized from {self.distribution} distribution."
208+
f"[Info] Node-level features will be padded with random"
209+
f"{self.pad_node_features} values from {self.distribution} distribution."
226210
)
227-
if self.random_pad_edge:
211+
if self.pad_edge_features:
228212
print(
229-
f"[Info] Edge-level features will be padded with "
230-
f"{self.random_pad_edge} additional dimensions initialized from {self.distribution} distribution."
213+
f"[Info] Edge-level features will be padded with random"
214+
f"{self.pad_edge_features} values from {self.distribution} distribution."
231215
)
232216

233217
if self.properties:
@@ -276,24 +260,19 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
276260
else:
277261
raise TypeError(f"Unsupported property type: {type(property).__name__}")
278262

279-
if self.zero_pad_node:
280-
x = torch.cat([x, torch.zeros((x.shape[0], self.zero_pad_node))], dim=1)
281-
282-
if self.zero_pad_edge:
283-
edge_attr = torch.cat(
284-
[edge_attr, torch.zeros((edge_attr.shape[0], self.zero_pad_edge))],
285-
dim=1,
263+
if self.pad_node_features:
264+
padding_values = torch.empty((x.shape[0], self.pad_node_features))
265+
RandomFeatureInitializationReader.random_gni(
266+
padding_values, self.distribution
286267
)
268+
x = torch.cat([x, padding_values], dim=1)
287269

288-
if self.random_pad_node:
289-
random_pad = torch.empty((x.shape[0], self.random_pad_node))
290-
RandomFeatureInitializationReader.random_gni(random_pad, self.distribution)
291-
x = torch.cat([x, random_pad], dim=1)
292-
293-
if self.random_pad_edge:
294-
random_pad = torch.empty((edge_attr.shape[0], self.random_pad_edge))
295-
RandomFeatureInitializationReader.random_gni(random_pad, self.distribution)
296-
edge_attr = torch.cat([edge_attr, random_pad], dim=1)
270+
if self.pad_edge_features:
271+
padding_values = torch.empty((edge_attr.shape[0], self.pad_edge_features))
272+
RandomFeatureInitializationReader.random_gni(
273+
padding_values, self.distribution
274+
)
275+
edge_attr = torch.cat([edge_attr, padding_values], dim=1)
297276

298277
return GeomData(
299278
x=x,
@@ -350,13 +329,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
350329
)
351330

352331
in_channels_str = ""
353-
if self.zero_pad_node:
354-
n_node_properties += self.zero_pad_node
355-
in_channels_str += f" (with {self.zero_pad_node} padded zeros)"
356-
357-
if self.random_pad_node:
358-
n_node_properties += self.random_pad_node
359-
in_channels_str += f" (with {self.random_pad_node} random padded values from {self.distribution} distribution)"
332+
if self.pad_node_features:
333+
n_node_properties += self.pad_node_features
334+
in_channels_str += f" (with {self.pad_node_features} padded random values from {self.distribution} distribution)"
360335

361336
in_channels_str = f"in_channels: {n_node_properties}" + in_channels_str
362337

@@ -367,14 +342,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
367342
if isinstance(p, BondProperty)
368343
)
369344
edge_dim_str = ""
370-
371-
if self.zero_pad_edge:
372-
n_edge_properties += self.zero_pad_edge
373-
edge_dim_str += f" (with {self.zero_pad_edge} padded zeros)"
374-
375-
if self.random_pad_edge:
376-
n_edge_properties += self.random_pad_edge
377-
edge_dim_str += f" (with {self.random_pad_edge} random padded values from {self.distribution} distribution)"
345+
if self.pad_edge_features:
346+
n_edge_properties += self.pad_edge_features
347+
edge_dim_str += f" (with {self.pad_edge_features} padded random values from {self.distribution} distribution)"
378348

379349
edge_dim_str = f"edge_dim: {n_edge_properties}" + edge_dim_str
380350

@@ -388,32 +358,6 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
388358

389359
return base_df[base_data[0].keys()].to_dict("records")
390360

391-
@property
392-
def processed_file_names_dict(self) -> dict:
393-
"""
394-
Returns a dictionary for the processed and tokenized data files.
395-
396-
Returns:
397-
dict: A dictionary mapping dataset keys to their respective file names.
398-
For example, {"data": "data.pt"}.
399-
"""
400-
if self.n_token_limit is not None:
401-
return {"data": f"data_maxlen{self.n_token_limit}.pt"}
402-
403-
data_pt_filename = "data"
404-
if self.zero_pad_node:
405-
data_pt_filename += f"_zpn{self.zero_pad_node}"
406-
if self.zero_pad_edge:
407-
data_pt_filename += f"_zpe{self.zero_pad_edge}"
408-
if self.random_pad_node:
409-
data_pt_filename += f"_rpn{self.random_pad_node}"
410-
if self.random_pad_edge:
411-
data_pt_filename += f"_rpe{self.random_pad_edge}"
412-
if self.random_pad_node or self.random_pad_edge:
413-
data_pt_filename += f"_D{self.distribution}"
414-
415-
return {"data": data_pt_filename + ".pt"}
416-
417361

418362
class GraphPropAsPerNodeType(DataPropertiesSetter, ABC):
419363
def __init__(self, properties=None, transform=None, **kwargs):

chebai_graph/preprocessing/reader/static_gni.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
class RandomFeatureInitializationReader(GraphPropertyReader):
16-
DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform"]
16+
DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"]
1717

1818
def __init__(
1919
self,
@@ -74,5 +74,7 @@ def random_gni(tensor: torch.Tensor, distribution: str) -> None:
7474
torch.nn.init.xavier_normal_(tensor)
7575
elif distribution == "xavier_uniform":
7676
torch.nn.init.xavier_uniform_(tensor)
77+
elif distribution == "zeros":
78+
torch.nn.init.zeros_(tensor)
7779
else:
7880
raise ValueError("Unknown distribution type")

0 commit comments

Comments
 (0)