Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions torch_geometric/loader/dynamic_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ class DynamicBatchSampler(torch.utils.data.sampler.Sampler):

Args:
dataset (Dataset): Dataset to sample from.
max_num (int): Size of mini-batch to aim for in number of nodes or
edges.
mode (str, optional): :obj:`"node"` or :obj:`"edge"` to measure
batch size. (default: :obj:`"node"`)
max_num_nodes (int): Size of mini-batch to aim for in number of nodes.
max_num_edges (int): Size of mini-batch to aim for in number of edges.
shuffle (bool, optional): If set to :obj:`True`, will have the data
reshuffled at every epoch. (default: :obj:`False`)
skip_too_big (bool, optional): If set to :obj:`True`, skip samples
Expand All @@ -42,22 +40,21 @@ class DynamicBatchSampler(torch.utils.data.sampler.Sampler):
def __init__(
self,
dataset: Dataset,
max_num: int,
mode: str = 'node',
max_num_nodes: int,
max_num_edges: int,
shuffle: bool = False,
skip_too_big: bool = False,
num_steps: Optional[int] = None,
):
if max_num <= 0:
raise ValueError(f"`max_num` should be a positive integer value "
f"(got {max_num})")
if mode not in ['node', 'edge']:
raise ValueError(f"`mode` choice should be either "
f"'node' or 'edge' (got '{mode}')")

if max_num_nodes <= 0:
raise ValueError(f"`max_num_nodes` should be a positive integer "
f"value (got {max_num_nodes})")
if max_num_edges <= 0:
raise ValueError(f"`max_num_edges` should be a positive integer "
f"value (got {max_num_edges})")
self.dataset = dataset
self.max_num = max_num
self.mode = mode
self.max_num_nodes = max_num_nodes
self.max_num_edges = max_num_edges
self.shuffle = shuffle
self.skip_too_big = skip_too_big
self.num_steps = num_steps
Expand All @@ -70,7 +67,8 @@ def __iter__(self) -> Iterator[List[int]]:
indices = range(len(self.dataset))

samples: List[int] = []
current_num: int = 0
current_num_nodes: int = 0
current_num_edges: int = 0
num_steps: int = 0
num_processed: int = 0

Expand All @@ -79,22 +77,29 @@ def __iter__(self) -> Iterator[List[int]]:

for i in indices[num_processed:]:
data = self.dataset[i]
num = data.num_nodes if self.mode == 'node' else data.num_edges

if current_num + num > self.max_num:
if current_num == 0:
num_nodes = data.num_nodes
num_edges = data.num_edges
# Check if adding a single graph will cause the mini-batch to
# more nodes/edges than the maximums given.
if (current_num_nodes + num_nodes
> self.max_num_nodes) or (current_num_edges + num_edges
> self.max_num_edges):

if current_num_nodes == 0 or current_num_edges == 0:
if self.skip_too_big:
continue
else: # Mini-batch filled:
break

samples.append(i)
num_processed += 1
current_num += num
current_num_nodes += num_nodes
current_num_edges += num_edges

yield samples
samples: List[int] = []
current_num = 0
current_num_nodes = 0
current_num_edges = 0
num_steps += 1

def __len__(self) -> int:
Expand Down