Hi, when I tried to train my GNN model, I got the following error:

```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-13-558bee3a48e3> in <module>
1 for epoch in range(num_epochs):
2 losses = []
----> 3 train_loss = train(epoch)
4 test_loss = test(epoch)
5 print('Epoch: {:02d}, Train MSE: {:.4f}, Test MSE: {:.4f}'.format(epoch, train_loss, test_loss))

<ipython-input-7-08d4439ea327> in train(epoch)
13 total_loss += loss
14 losses.append(loss)
---> 15 loss.backward()
16 optimizer.step()
17

~/miniconda3/envs/gnb_df_pyg/lib/python3.7/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
253 create_graph=create_graph,
254 inputs=inputs)
--> 255 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
256
257 def register_hook(self, hook):

~/miniconda3/envs/gnb_df_pyg/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
147 Variable._execution_engine.run_backward(
148 tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 149 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
150
151

RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
```

My loss function (for a variational model) looks like this:

```python
import torch

def final_loss(bce_loss, mu, logvar):
    # Reconstruction term, computed elsewhere
    BCE = bce_loss
    # KL divergence between N(mu, exp(logvar)) and the N(0, I) prior
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
```

I thought it might be an issue from large …
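For reference, the `KLD` term in `final_loss` is the closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior. A minimal sketch, assuming `logvar` holds log σ² (as the `logvar.exp()` term suggests), that checks the expression against `torch.distributions.kl_divergence` on random tensors; the helper name `kld_closed_form` is hypothetical:

```python
import torch
from torch.distributions import Normal, kl_divergence


def kld_closed_form(mu, logvar):
    # Same expression as in final_loss: -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


torch.manual_seed(0)
# Random posterior parameters, in float64 to keep the comparison tight
mu = torch.randn(4, 8, dtype=torch.float64)
logvar = torch.randn(4, 8, dtype=torch.float64)

# Posterior N(mu, sigma) with sigma = exp(logvar / 2); prior N(0, 1)
posterior = Normal(mu, torch.exp(0.5 * logvar))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))

# kl_divergence(p, q) returns the elementwise KL(p || q); sum to match the closed form
reference = kl_divergence(posterior, prior).sum()
closed_form = kld_closed_form(mu, logvar)
print(torch.allclose(closed_form, reference))
```

Because the KL term is a plain sum over all latent dimensions and batch elements, its magnitude grows with batch size and latent size, which is worth keeping in mind when balancing it against `bce_loss`.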

---

Thanks for the issue. The cause of this error is very hard to tell from the error message alone. What happens if you run the script with `CUDA_LAUNCH_BLOCKING=1` or inside `torch.autograd.detect_anomaly()`? Can you also share more information about your model architecture? Is there a minimal reproducible example?
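A minimal sketch of both suggestions, assuming a standard PyTorch training step. `CUDA_LAUNCH_BLOCKING=1` makes CUDA kernel launches synchronous, so the traceback points at the operation that actually failed rather than a later one; it only takes effect if set before CUDA is initialized (e.g. `CUDA_LAUNCH_BLOCKING=1 python train.py` from a shell, or at the very top of a freshly restarted notebook). `torch.autograd.detect_anomaly()` makes an error in the backward pass also report the forward operation that created the failing node. The names `train_step` and `mse_loss_fn` and the toy linear model are hypothetical stand-ins for the GNN and `final_loss`:

```python
import os

# Synchronous kernel launches; only takes effect if set before CUDA is initialized
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn.functional as F


def train_step(model, batch, optimizer, loss_fn):
    optimizer.zero_grad()
    # Anomaly mode slows training down, so enable it only while debugging
    with torch.autograd.detect_anomaly():
        loss = loss_fn(model, batch)
        loss.backward()
    optimizer.step()
    # Return a Python float so callers that accumulate losses do not keep
    # every iteration's autograd graph alive
    return loss.item()


def mse_loss_fn(model, batch):
    # Hypothetical stand-in for the real loss (e.g. final_loss)
    inputs, targets = batch
    return F.mse_loss(model(inputs), targets)


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = (torch.randn(8, 4, device=device), torch.randn(8, 1, device=device))

step_loss = train_step(model, batch, optimizer, mse_loss_fn)
print(f"loss: {step_loss:.4f}")
```

With both enabled, the resulting traceback should be closer to the operation that actually triggers the cuBLAS failure, which makes it easier to share a minimal reproducible example.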