It looks like you need to apply the following fixes:

  • `data.pos` needs to be a floating-point tensor.
  • `RadiusGraph(1)` only includes connections whose distance is strictly less than 1, so neighbours at exactly distance 1 are dropped. Adding a small value to the radius fixes this.
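```python
import math

import numpy as np
import torch

import torch_geometric.transforms as T
from torch_geometric.data import Data as gData

# Add a small epsilon so neighbours at exactly sqrt(2) (the diagonals) are kept.
pre_trans = T.RadiusGraph(math.sqrt(2) + 1e-5)
test_graph = gData(
    x=torch.tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])),
    pos=torch.tensor(
        np.array([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0],
                  [2, 1], [2, 2]])),
)
# An integer NumPy array becomes an int64 tensor; RadiusGraph needs float positions.
test_graph.pos = test_graph.pos.to(torch.float)
# The transform builds edge_index from pos and returns the Data object.
test_graph = pre_trans(test_graph)
print(test_graph.edge_index.shape)
```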
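To see why the epsilon matters, here is a minimal pure-Python sketch that counts neighbours on the same 3x3 grid, assuming (as stated above) that the radius check is a strict `<`. The helper `radius_edges` is hypothetical and only mimics that comparison; it is not part of `torch_geometric`.

```python
import math
from itertools import product


def radius_edges(pos, r):
    """Hypothetical helper: directed edges (i, j), i != j, with distance strictly below r.

    Mimics the strict `<` radius check described above; not a torch_geometric API.
    """
    return [
        (i, j)
        for i, j in product(range(len(pos)), repeat=2)
        if i != j and math.dist(pos[i], pos[j]) < r
    ]


# Same 3x3 grid as in the snippet above, as floats and in the same node order.
grid = [(float(a), float(b)) for a in range(3) for b in range(3)]

for r in (1.0, 1.0 + 1e-5, math.sqrt(2), math.sqrt(2) + 1e-5):
    print(f"r = {r:.6f}: {len(radius_edges(grid, r))} directed edges")
```

With a strict comparison, a radius of exactly 1 yields no edges and a radius of exactly `sqrt(2)` misses every diagonal; the epsilon versions give 24 (4-neighbourhood) and 40 (8-neighbourhood) directed edges respectively.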

Answer selected by wwyi1828