diff --git a/torch_geometric/loader/dynamic_batch_sampler.py b/torch_geometric/loader/dynamic_batch_sampler.py index 6df8f14328fc..cbdf0ef0e1e4 100644 --- a/torch_geometric/loader/dynamic_batch_sampler.py +++ b/torch_geometric/loader/dynamic_batch_sampler.py @@ -26,10 +26,8 @@ class DynamicBatchSampler(torch.utils.data.sampler.Sampler): Args: dataset (Dataset): Dataset to sample from. - max_num (int): Size of mini-batch to aim for in number of nodes or - edges. - mode (str, optional): :obj:`"node"` or :obj:`"edge"` to measure - batch size. (default: :obj:`"node"`) + max_num_nodes (int): Size of mini-batch to aim for in number of nodes. + max_num_edges (int): Size of mini-batch to aim for in number of edges. shuffle (bool, optional): If set to :obj:`True`, will have the data reshuffled at every epoch. (default: :obj:`False`) skip_too_big (bool, optional): If set to :obj:`True`, skip samples @@ -42,22 +40,21 @@ class DynamicBatchSampler(torch.utils.data.sampler.Sampler): def __init__( self, dataset: Dataset, - max_num: int, - mode: str = 'node', + max_num_nodes: int, + max_num_edges: int, shuffle: bool = False, skip_too_big: bool = False, num_steps: Optional[int] = None, ): - if max_num <= 0: - raise ValueError(f"`max_num` should be a positive integer value " - f"(got {max_num})") - if mode not in ['node', 'edge']: - raise ValueError(f"`mode` choice should be either " - f"'node' or 'edge' (got '{mode}')") - + if max_num_nodes <= 0: + raise ValueError(f"`max_num_nodes` should be a positive integer " + f"value (got {max_num_nodes})") + if max_num_edges <= 0: + raise ValueError(f"`max_num_edges` should be a positive integer " + f"value (got {max_num_edges})") self.dataset = dataset - self.max_num = max_num - self.mode = mode + self.max_num_nodes = max_num_nodes + self.max_num_edges = max_num_edges self.shuffle = shuffle self.skip_too_big = skip_too_big self.num_steps = num_steps @@ -70,7 +67,8 @@ def __iter__(self) -> Iterator[List[int]]: indices = 
range(len(self.dataset)) samples: List[int] = [] - current_num: int = 0 + current_num_nodes: int = 0 + current_num_edges: int = 0 num_steps: int = 0 num_processed: int = 0 @@ -79,10 +77,15 @@ def __iter__(self) -> Iterator[List[int]]: for i in indices[num_processed:]: data = self.dataset[i] - num = data.num_nodes if self.mode == 'node' else data.num_edges - - if current_num + num > self.max_num: - if current_num == 0: + num_nodes = data.num_nodes + num_edges = data.num_edges + # Check if adding a single graph will cause the mini-batch to + # have more nodes/edges than the maximums given. + if (current_num_nodes + num_nodes + > self.max_num_nodes) or (current_num_edges + num_edges + > self.max_num_edges): + + if current_num_nodes == 0 or current_num_edges == 0: if self.skip_too_big: continue else: # Mini-batch filled: @@ -90,11 +93,13 @@ def __iter__(self) -> Iterator[List[int]]: samples.append(i) num_processed += 1 - current_num += num + current_num_nodes += num_nodes + current_num_edges += num_edges yield samples samples: List[int] = [] - current_num = 0 + current_num_nodes = 0 + current_num_edges = 0 num_steps += 1 def __len__(self) -> int: