56 changes: 44 additions & 12 deletions examples/multi_gpu/ogbn_train_cugraph.py
@@ -63,6 +63,8 @@ def arg_parse():
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--wd', type=float, default=0.000)
parser.add_argument('-e', '--epochs', type=int, default=50)
parser.add_argument('-le', '--local_epochs', type=int, default=50,
help='warmup epochs (local attention only) for Polynormer')
parser.add_argument('-b', '--batch_size', type=int, default=1024)
parser.add_argument('--fan_out', type=int, default=10)
parser.add_argument('--warmup_steps', type=int, default=20)
@@ -87,6 +89,7 @@ def arg_parse():
'GCN',
# TODO: Uncomment when we add support for disjoint sampling
# 'SGFormer',
# 'Polynormer',
],
help="Model used for training, default GCN",
)
@@ -112,15 +115,19 @@ def arg_parse():
return args


def evaluate(rank, loader, model):
def evaluate(args, rank, loader, model):
with torch.no_grad():
total_correct = total_examples = 0
for batch in loader:
batch = batch.to(rank)
batch_size = batch.batch_size

batch.y = batch.y.to(torch.long)
out = model(batch.x, batch.edge_index)[:batch_size]
if args.model in ['SGFormer', 'Polynormer']:
out = model(batch.x, batch.edge_index,
batch.batch)[:batch_size]
else:
out = model(batch.x, batch.edge_index)[:batch_size]

pred = out.argmax(dim=-1)
y = batch.y[:batch_size].view(-1).to(torch.long)
@@ -161,6 +168,8 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,
num_classes, wall_clock_start):

epochs = args.epochs
if args.model == 'Polynormer':
epochs += args.local_epochs
batch_size = args.batch_size
fan_out = args.fan_out
num_layers = args.num_layers
@@ -172,13 +181,15 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,
)

model = model.to(rank)
model = DistributedDataParallel(model, device_ids=[rank])
model = DistributedDataParallel(model, device_ids=[rank],
find_unused_parameters=True)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
weight_decay=args.wd)

kwargs = dict(
num_neighbors=[fan_out] * num_layers,
batch_size=batch_size,
disjoint=args.model in ['SGFormer', 'Polynormer'],
)
from cugraph_pyg.data import GraphStore, TensorDictFeatureStore
from cugraph_pyg.loader import NeighborLoader
@@ -253,7 +264,13 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,
batch_size = batch.batch_size
batch.y = batch.y.to(torch.long)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
if args.model in ['SGFormer', 'Polynormer']:
if args.model == 'Polynormer' and epoch == args.local_epochs: # noqa: E501
print('start global attention')
model.module._global = True
out = model(batch.x, batch.edge_index, batch.batch)
else:
out = model(batch.x, batch.edge_index)
loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size])
loss.backward()
optimizer.step()
@@ -266,9 +283,9 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,
torch.cuda.synchronize()

inference_start = time.perf_counter()
train_acc = evaluate(rank, train_loader, model)
train_acc = evaluate(args, rank, train_loader, model)
dist.barrier()
val_acc = evaluate(rank, val_loader, model)
val_acc = evaluate(args, rank, val_loader, model)
dist.barrier()

inference_times.append(time.perf_counter() - inference_start)
@@ -298,7 +315,7 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,

if rank == 0:
print("Testing...")
final_test_acc = evaluate(rank, test_loader, model)
final_test_acc = evaluate(args, rank, test_loader, model)
dist.barrier()
if rank == 0:
print(f'Test Accuracy: {final_test_acc:.4f} for rank: {rank:02d}')
@@ -315,6 +332,11 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,

args = arg_parse()
seed_everything(123)
if args.model == 'Polynormer' and args.num_layers != 7:
print(
"The original Polynormer paper recommends 7 layers; you have "
"chosen", args.num_layers, "which may affect results. "
"See the paper for details.")
wall_clock_start = time.perf_counter()

root = osp.join(args.dataset_dir, args.dataset_subdir)
@@ -331,11 +353,13 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,

print(f"Training {args.dataset} with {args.model} model.")
if args.model == "GAT":
model = torch_geometric.nn.models.GAT(dataset.num_features,
args.hidden_channels,
args.num_layers,
dataset.num_classes,
heads=args.num_heads)
model = torch_geometric.nn.models.GAT(
dataset.num_features,
args.hidden_channels,
args.num_layers,
dataset.num_classes,
heads=args.num_heads,
)
elif args.model == "GCN":
model = torch_geometric.nn.models.GCN(
dataset.num_features,
@@ -361,6 +385,14 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,
gnn_num_layers=args.num_layers,
gnn_dropout=args.dropout,
)
elif args.model == 'Polynormer':
# TODO add support for this with disjoint sampling
model = torch_geometric.nn.models.Polynormer(
in_channels=dataset.num_features,
hidden_channels=args.hidden_channels,
out_channels=dataset.num_classes,
local_layers=args.num_layers,
)
else:
raise ValueError(f'Unsupported model type: {args.model}')

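For context before the single-GPU variant below: this script trains Polynormer in two phases. Epochs up to --local_epochs run with the model's initial _global = False state (local message passing only); once epoch == args.local_epochs, the flag is flipped and the remaining --epochs use global attention. find_unused_parameters=True is needed because the global-attention parameters receive no gradients during warmup, which DDP would otherwise report as an error. A minimal single-process sketch of that schedule (the helper name train_two_phase and the loop shape are illustrative, not part of this PR):

import torch.nn.functional as F

def train_two_phase(model, loader, optimizer, local_epochs, epochs):
    # Phase 1: Polynormer starts with _global = False, so the first
    # epochs use local message passing only.
    for epoch in range(1, local_epochs + epochs + 1):
        if epoch == local_epochs:
            # Phase 2: enable global attention, mirroring the switch
            # in run_train(). Under DDP, flip it via model.module.
            model._global = True
        for batch in loader:
            optimizer.zero_grad()
            # Disjoint sampling attaches batch.batch, mapping each
            # sampled node to its seed's subgraph.
            out = model(batch.x, batch.edge_index, batch.batch)
            y = batch.y[:batch.batch_size].view(-1).long()
            loss = F.cross_entropy(out[:batch.batch_size], y)
            loss.backward()
            optimizer.step()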
62 changes: 48 additions & 14 deletions examples/ogbn_train_cugraph.py
@@ -54,6 +54,8 @@ def arg_parse():
help="directory of dataset.",
)
parser.add_argument('-e', '--epochs', type=int, default=50)
parser.add_argument('-le', '--local_epochs', type=int, default=50,
help='warmup epochs (local attention only) for Polynormer')
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('-b', '--batch_size', type=int, default=1024)
parser.add_argument('--fan_out', type=int, default=10)
@@ -83,6 +85,7 @@ def arg_parse():
'GCN',
# TODO: Uncomment when we add support for disjoint sampling
# 'SGFormer',
# 'Polynormer',
],
help="Model used for training, default SAGE",
)
@@ -104,6 +107,7 @@ def create_loader(
replace,
batch_size,
stage_name,
disjoint,
shuffle=False,
):
print(f'Creating {stage_name} loader...')
@@ -115,17 +119,25 @@
replace=replace,
batch_size=batch_size,
shuffle=shuffle,
disjoint=disjoint,
)


def train(model, train_loader):
def train(args, model, train_loader, epoch):
model.train()

total_loss = total_correct = total_examples = 0
for batch in train_loader:
batch = batch.cuda()
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)[:batch.batch_size]
if args.model in ['SGFormer', 'Polynormer']:
if args.model == 'Polynormer' and epoch == args.local_epochs:
print('start global attention')
model._global = True
out = model(batch.x, batch.edge_index,
batch.batch)[:batch.batch_size]
else:
out = model(batch.x, batch.edge_index)[:batch.batch_size]
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
loss = F.cross_entropy(out, y)
loss.backward()
@@ -140,13 +152,17 @@ def train(model, train_loader):


@torch.no_grad()
def test(model, loader):
def test(args, model, loader):
model.eval()

total_correct = total_examples = 0
for batch in loader:
batch = batch.cuda()
out = model(batch.x, batch.edge_index)[:batch.batch_size]
if args.model in ['SGFormer', 'Polynormer']:
out = model(batch.x, batch.edge_index,
batch.batch)[:batch.batch_size]
else:
out = model(batch.x, batch.edge_index)[:batch.batch_size]
y = batch.y[:batch.batch_size].view(-1).to(torch.long)

total_correct += out.argmax(dim=-1).eq(y).sum()
@@ -158,6 +174,11 @@ def test(model, loader):
if __name__ == '__main__':
args = arg_parse()
torch_geometric.seed_everything(123)
if args.model == 'Polynormer' and args.num_layers != 7:
print(
"The original Polynormer paper recommends 7 layers; you have "
"chosen", args.num_layers, "which may affect results. "
"See the paper for details.")
if "papers" in str(args.dataset) and (psutil.virtual_memory().total /
(1024**3)) < 390:
print("Warning: may not have enough RAM to use this many GPUs.")
@@ -196,11 +217,13 @@ def test(model, loader):

print(f"Training {args.dataset} with {args.model} model.")
if args.model == "GAT":
model = torch_geometric.nn.models.GAT(dataset.num_features,
args.hidden_channels,
args.num_layers,
dataset.num_classes,
heads=args.num_heads).cuda()
model = torch_geometric.nn.models.GAT(
dataset.num_features,
args.hidden_channels,
args.num_layers,
dataset.num_classes,
heads=args.num_heads,
).cuda()
elif args.model == "GCN":
model = torch_geometric.nn.models.GCN(
dataset.num_features,
@@ -226,8 +249,16 @@ def test(model, loader):
gnn_num_layers=args.num_layers,
gnn_dropout=args.dropout,
).cuda()
elif args.model == 'Polynormer':
# TODO add support for this with disjoint sampling
model = torch_geometric.nn.models.Polynormer(
in_channels=dataset.num_features,
hidden_channels=args.hidden_channels,
out_channels=dataset.num_classes,
local_layers=args.num_layers,
).cuda()
else:
raise ValueError('Unsupported model type: {args.model}')
raise ValueError(f'Unsupported model type: {args.model}')

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
weight_decay=args.wd)
@@ -237,6 +268,7 @@ def test(model, loader):
num_neighbors=[args.fan_out] * args.num_layers,
replace=False,
batch_size=args.batch_size,
disjoint=args.model in ['SGFormer', 'Polynormer'],
)

train_loader = create_loader(
@@ -268,14 +300,16 @@ def test(model, loader):
best_val = 0.
start = time.perf_counter()
epochs = args.epochs
if args.model == 'Polynormer':
epochs += args.local_epochs
for epoch in range(1, epochs + 1):
train_start = time.perf_counter()
loss, train_acc = train(model, train_loader)
loss, train_acc = train(args, model, train_loader, epoch)
train_end = time.perf_counter()
train_times.append(train_end - train_start)
inference_start = time.perf_counter()
train_acc = test(model, train_loader)
val_acc = test(model, val_loader)
train_acc = test(args, model, train_loader)
val_acc = test(args, model, val_loader)

inference_times.append(time.perf_counter() - inference_start)
val_accs.append(val_acc)
@@ -300,7 +334,7 @@ def test(model, loader):
print(f"Best validation accuracy: {best_val:.4f}")

print("Testing...")
final_test_acc = test(model, test_loader)
final_test_acc = test(args, model, test_loader)
print(f'Test Accuracy: {final_test_acc:.4f}')

total_time = round(time.perf_counter() - wall_clock_start, 2)
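Both scripts gate SGFormer and Polynormer behind disjoint neighbor sampling: with disjoint=True, every seed node gets its own independent subgraph, and the loader attaches a batch.batch vector assigning each sampled node to its seed's subgraph, which the attention models take as their third argument. A minimal sketch of what that produces, using upstream PyG's NeighborLoader (which exposes the same disjoint flag as the cugraph-pyg loader; the toy graph is illustrative):

import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Toy graph: 100 nodes with random features and labels.
data = Data(
    x=torch.randn(100, 16),
    edge_index=torch.randint(0, 100, (2, 400)),
    y=torch.randint(0, 4, (100,)),
)
loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],        # fan_out per layer
    batch_size=8,
    input_nodes=torch.arange(100),
    disjoint=True,                 # one independent subgraph per seed
)
batch = next(iter(loader))
# With disjoint=True, batch.batch[i] is the index of the seed whose
# subgraph contains sampled node i; the seeds themselves come first,
# so the first batch_size entries are 0..batch_size-1.
print(batch.batch[:batch.batch_size])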