How to customize training loop? #7549
I want to marry Lightning and https://pytorch-geometric.readthedocs.io/en/latest/, or in particular https://pytorch-geometric-temporal.readthedocs.io/en/latest/. The basic examples on their website, such as the one for `ChickenpoxDatasetLoader()`, construct a `RecurrentGCN`. Even as a total Lightning newbie, it is already pretty clear to me how to convert that model into a regular LightningModule (kudos to the easy API so far). However, it is rather unclear to me how to put the training loop into a Lightning-compatible trainer:

```python
import torch
from tqdm import tqdm

# RecurrentGCN and train_dataset are defined as in the
# pytorch-geometric-temporal chickenpox example
model = RecurrentGCN(node_features=4)  # chickenpox
model  # in a notebook, this displays the module
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()
for epoch in tqdm(range(200)):
    cost = 0
    # Accumulate the loss over every temporal snapshot...
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        cost = cost + torch.mean((y_hat - snapshot.y) ** 2)
    # ...then take a single optimizer step per epoch
    cost = cost / (time + 1)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()
```

Would I need a custom trainer in Lightning?

For reference, my data looks like this:

```python
import pandas as pd

df = pd.DataFrame({'person_1': [], 'person_2': [], 'time': [], 'lat': [], 'long': [], 'hobby': []})
display(df)  # Jupyter/IPython; use print(df) in a plain script
```

I want to perform link prediction, i.e. recommend new friends based on similar hobbies in similar locations and time ranges. In the pytorch-geometric-temporal framework, the data is organized as snapshots over time. These are not meant as batches: currently the framework assumes that a snapshot contains all the data for its particular span of time in a single batch. However, the default trainer does not offer a way to iterate over snapshots, and it is unclear to me how to include it.
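One way to get from the table above to snapshot-style data is to bucket the rows into fixed time windows and build one edge list per window. This is a minimal sketch under stated assumptions: `time` is numeric (e.g. seconds), each row is an undirected friendship between `person_1` and `person_2`, and `lat`, `long` and `hobby` are ignored (they would become node or edge features). `build_snapshots` and `window` are hypothetical names, not part of either library.

```python
import numpy as np
import pandas as pd


def build_snapshots(df: pd.DataFrame, window: float):
    """Hypothetical helper: one PyG-style edge_index per time window.

    Assumes `time` is numeric and each row is an undirected friendship
    between person_1 and person_2. lat/long/hobby are ignored here.
    """
    # Shared integer node ids, so node i means the same person in every snapshot
    people = pd.unique(pd.concat([df["person_1"], df["person_2"]]))
    node_id = {person: i for i, person in enumerate(people)}

    # Bucket rows into consecutive, non-overlapping windows of length `window`
    bucket = (df["time"] // window).astype(int)

    snapshots = []
    for _, rows in df.groupby(bucket, sort=True):
        src = rows["person_1"].map(node_id).to_numpy()
        dst = rows["person_2"].map(node_id).to_numpy()
        # edge_index has shape [2, num_edges]; store both directions
        edge_index = np.stack([np.concatenate([src, dst]), np.concatenate([dst, src])])
        snapshots.append(edge_index)
    return node_id, snapshots
```

Windows without any interactions produce no snapshot here. Each array can be turned into a tensor with `torch.from_numpy(edge_index)` before it is handed to a graph layer.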
Replies: 2 comments 2 replies
Just found https://github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/master/examples/lightning_example.py. It seems to answer my question, almost: it does not cover how to handle the iteration over the temporal snapshots.
I don't know much about PyTorch Geometric, but are you trying to have an additional state (tensor) that is passed between steps?
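If the model does carry a hidden state from one snapshot to the next, that state likely needs no special Trainer support: it can live entirely inside one `training_step`. This sketch uses a plain `torch.nn.GRUCell` applied per node as a hypothetical stand-in for a recurrent graph layer (a real one would also take `edge_index`/`edge_weight`); `RecurrentStandIn` and `sequence_loss` are illustrative names.

```python
import torch
import torch.nn.functional as F
from torch import nn


class RecurrentStandIn(nn.Module):
    """Hypothetical stand-in for a recurrent GNN: a GRUCell applied per node."""

    def __init__(self, in_features: int, hidden: int):
        super().__init__()
        self.cell = nn.GRUCell(in_features, hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x, h=None):
        h = self.cell(x, h)  # new hidden state, [num_nodes, hidden]; h=None means zeros
        return self.head(h).squeeze(-1), h


def sequence_loss(model, snapshots):
    """Mean MSE over a sequence, threading the hidden state from one
    snapshot to the next (the state passed between steps)."""
    h = None  # fresh state at the start of the sequence
    cost = 0.0
    for x, y in snapshots:
        y_hat, h = model(x, h)
        cost = cost + F.mse_loss(y_hat, y)
    return cost / len(snapshots)
```

Returning `sequence_loss(self.model, snapshots)` from `training_step` backpropagates through the whole sequence. For long sequences, calling `h = h.detach()` every few snapshots would truncate the backward pass (truncated BPTT).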