Description:
I've encountered a bug related to the data loading flow when `load_processed_data` is overridden in the `GraphPropertiesMixIn` class of the python-chebai-graph repository, whose subclasses inherit from `_DynamicDataset` (a subclass of `XYBaseDataModule`).
Context
In _DynamicDataset, we introduced dynamic data splits with the ability to persist them into a splits.csv file. This allows us to load the same splits in future runs by passing the splits_file_path to the data class.
- The `load_processed_data(kind=...)` method handles loading the relevant partition of the data (train/val/test) by accessing the `dynamic_splits_df` property.
- If a `filename` is provided instead of `kind`, the method loads the serialized data object using `torch.load(filename)`.
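For orientation, here is a simplified sketch of these two code paths. It paraphrases the behaviour described above and is not the repository's actual implementation; names such as `processed_file_path`, `_combine_data_and_splits`, and the `split` column are assumptions for illustration.

```python
import pandas as pd
import torch


class _DynamicDataset:  # the real class extends XYBaseDataModule; parent omitted here
    @property
    def dynamic_splits_df(self):
        return self._retrieve_splits_from_csv()

    def _retrieve_splits_from_csv(self):
        splits_df = pd.read_csv(self.splits_file_path)
        # The serialized data is loaded through the *public* method -- this is
        # the call that becomes problematic once a subclass overrides it.
        data = self.load_processed_data(filename=self.processed_file_path)
        return self._combine_data_and_splits(data, splits_df)  # hypothetical helper

    def load_processed_data(self, kind=None, filename=None):
        # Mode 1: an explicit filename -> deserialize it directly.
        if filename is not None:
            return torch.load(filename)
        # Mode 2: a partition name (train/val/test) -> resolve it via the splits.
        splits_df = self.dynamic_splits_df
        return splits_df[splits_df["split"] == kind]
```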
Normal stack flow with `splits_file_path`:

val_dataloader
└── dataloader
    └── load_processed_data(kind="val")
        └── dynamic_splits_df
            └── _retrieve_splits_from_csv
                └── load_processed_data(filename="data.pt")
                    └── torch.load(...)

This flow works as expected in the base `_DynamicDataset` class.
Problem
When `load_processed_data` is overridden in a subclass (e.g., `GraphDataset`), it first calls the superclass method and then applies additional logic. This causes an unexpected recursive call pattern:
val_dataloader
└── dataloader
    └── GraphDataset.load_processed_data(kind="val")
        └── _DynamicDataset.load_processed_data(kind="val")
            └── dynamic_splits_df
                └── _retrieve_splits_from_csv
                    └── GraphDataset.load_processed_data(filename="data.pt") ❌
                        └── _DynamicDataset.load_processed_data(filename="data.pt")
                            └── torch.load(...)
                    ⮑ returns to GraphDataset.load_processed_data (1st call)
                        └── executes GraphDataset logic
    ⮑ returns to GraphDataset.load_processed_data (original call)
        └── executes GraphDataset logic **again**

This results in `GraphDataset`'s logic being executed twice, leading to incorrect behavior and eventually triggering the following error:
File "G:\github-aditya0by0\python-chebai-graph\chebai_graph\preprocessing\datasets\chebi.py", line 198, in load_processed_data
base_df = base_df.merge(property_df, on="ident", how="left")
File "G:\anaconda3\envs\gnn3\lib\site-packages\pandas\core\frame.py", line 10832, in merge
return merge(
File "G:\anaconda3\envs\gnn3\lib\site-packages\pandas\core\reshape\merge.py", line 170, in merge
op = _MergeOperation(
File "G:\anaconda3\envs\gnn3\lib\site-packages\pandas\core\reshape\merge.py", line 794, in __init__
) = self._get_merge_keys()
File "G:\anaconda3\envs\gnn3\lib\site-packages\pandas\core\reshape\merge.py", line 1310, in _get_merge_keys
left_keys.append(left._get_label_or_level_values(lk))
File "G:\anaconda3\envs\gnn3\lib\site-packages\pandas\core\generic.py", line 1911, in _get_label_or_level_values
raise KeyError(key)
KeyError: 'ident'
The `KeyError` above is a direct consequence of the `GraphDataset` logic running twice: the second execution receives data that has already been post-processed once, so the `merge` on `ident` no longer finds the expected key.
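The underlying mechanism is ordinary Python method dispatch: any call made through `self` inside the base class resolves to the subclass override, even when the base class intends to call its own implementation. A minimal, repository-independent demonstration of how this makes the subclass post-processing run twice:

```python
class Base:
    def load(self, kind=None, filename=None):
        if filename is not None:
            return ["raw"]  # stands in for torch.load(filename)
        # This call goes through `self`, so it dispatches to the subclass
        # override -- mirroring the call marked with ❌ above.
        return self.load(filename="data.pt")


class Sub(Base):
    def load(self, kind=None, filename=None):
        data = super().load(kind=kind, filename=filename)
        data.append("extra logic")  # subclass post-processing
        return data


print(Sub().load(kind="val"))
# ['raw', 'extra logic', 'extra logic']  <- post-processing applied twice
```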
Proposed Fix
- Option 1: Avoid calling `load_processed_data` recursively inside `_retrieve_splits_from_csv`; use `torch.load` directly instead (see the first sketch after this list). In my opinion this is the simplest option to implement; the other option would be somewhat more time-consuming and complex.
- Option 2: Separate the logic of `load_processed_data` into two internal methods:
  - `_load_partitioned_data(kind=...)`
  - `_load_file_data(filename=...)`

  Then call the appropriate method internally to avoid the recursion (see the second sketch after this list).
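A minimal sketch of Option 1, assuming `_retrieve_splits_from_csv` currently reaches the serialized data via `self.load_processed_data(filename=...)`; the attribute `processed_file_path` and the helper `_combine_data_and_splits` are placeholders, not the repository's actual names:

```python
import pandas as pd
import torch


class _DynamicDataset:  # parent class omitted for brevity
    def _retrieve_splits_from_csv(self):
        splits_df = pd.read_csv(self.splits_file_path)
        # Before (problematic): goes through the public method, which a
        # subclass may have overridden:
        #     data = self.load_processed_data(filename=self.processed_file_path)
        # After (Option 1): deserialize the file directly, so no subclass
        # post-processing can be triggered from the split-resolution path.
        data = torch.load(self.processed_file_path)
        return self._combine_data_and_splits(data, splits_df)  # placeholder helper
```

And a sketch of Option 2, again with paraphrased internals (the `split` column name and the partition filtering are assumptions):

```python
import torch


class _DynamicDataset:  # parent class omitted for brevity
    def load_processed_data(self, kind=None, filename=None):
        # Thin public wrapper that only dispatches to internal helpers.
        if filename is not None:
            return self._load_file_data(filename)
        return self._load_partitioned_data(kind)

    def _load_file_data(self, filename):
        # Pure deserialization; safe to call from _retrieve_splits_from_csv
        # without ever re-entering a subclass override.
        return torch.load(filename)

    def _load_partitioned_data(self, kind):
        splits_df = self.dynamic_splits_df  # internally uses _load_file_data
        return splits_df[splits_df["split"] == kind]


class GraphDataset(_DynamicDataset):
    def load_processed_data(self, kind=None, filename=None):
        base_data = super().load_processed_data(kind=kind, filename=filename)
        # Graph-specific post-processing now runs exactly once, because the
        # internal helpers never call back into this public method.
        return base_data
```

Either way avoids re-entering a subclass override from inside the split-resolution code, so the graph-specific post-processing in `GraphDataset` runs only once per dataloader call.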