Question on batch.n_id and batch.input_id, which are generated by NeighborLoader() #8314
-
Hi all, PyG is really useful! But I have a question about batch.n_id and batch.input_id as generated by NeighborLoader(). In the docs, it claims
I think the
Best,
My code:
datamodule = LightningNodeData(
data=graph,
input_train_nodes = graph.train_mask,
input_val_nodes = graph.val_mask,
input_test_nodes = graph.test_mask,
loader='neighbor',
num_neighbors=CONFIG.transfer.loader.num_neighbors,
batch_size=CONFIG.transfer.loader.batch_size,
subgraph_type="bidirectional",
disjoint=True,
shuffle=True,
num_workers=CONFIG.transfer.loader.num_workers,
)
trainer = Trainer(
accelerator=config.device,
devices=device_id,
callbacks=[progress_bar, early_stop, model_checkpoint, lr_monitor],
precision=precision,
default_root_dir=save_dir,
strategy=strategy,
gradient_clip_val=config.gradient_clip_val,
logger=logger,
log_every_n_steps=config.log_every_n_steps,
max_epochs=config.epochs,
)
trainer.fit(model, datamodule=data)
Data:
>>> batch
Data(x=[1543, 2500], edge_index=[2, 2830], train_mask=[1543], val_mask=[1543], test_mask=[1543], n_id=[1543], e_id=[2830], batch=[1543], input_id=[128], batch_size=128)
>>> batch.n_id[:128]
tensor([111360, 129061, 15483, 2430, 133970, 113119, 121077, 23134, 64944,
153408, 45093, 124491, 166141, 66873, 134996, 14232, 26973, 147208,
98417, 47202, 65048, 11656, 94083, 108200, 85756, 105963, 15051,
86647, 18871, 121700, 96400, 12951, 145768, 54487, 143461, 84269,
158634, 107187, 163444, 101727, 95188, 5004, 64955, 22269, 51641,
154142, 76768, 50338, 66808, 131496, 132804, 30637, 19455, 79564,
146635, 52605, 42512, 88800, 124091, 46018, 138140, 52277, 156148,
75590, 115781, 101977, 109965, 123812, 47807, 160965, 50005, 60997,
154879, 41772, 28247, 80973, 67898, 102251, 98882, 71784, 30744,
120500, 49730, 103174, 49850, 14590, 53873, 112457, 9370, 119005,
15163, 140524, 27497, 152293, 116393, 15735, 63396, 23356, 123697,
149211, 84516, 164876, 83821, 14096, 16759, 101058, 94775, 76286,
110523, 52011, 52037, 21314, 142840, 60790, 64035, 112363, 159576,
135442, 93875, 164150, 124530, 117294, 156725, 7850, 139545, 148136,
27110, 64132], device='cuda:3')
>>> batch.input_id
tensor([ 89226, 103303, 12353, 1927, 107198, 90617, 96971, 18443, 51915,
122793, 36001, 99698, 132933, 53493, 108026, 11360, 21500, 117796,
78854, 37695, 51994, 9305, 75374, 86701, 68625, 84905, 12005,
69347, 15056, 97476, 77221, 10338, 116632, 43548, 114791, 67432,
126872, 85882, 130747, 81510, 76267, 3966, 51923, 17778, 41240,
123350, 61379, 40184, 53443, 105233, 106271, 24407, 15516, 63627,
117342, 42026, 33906, 71095, 99377, 36760, 110514, 41764, 124946,
60433, 92710, 81717, 88101, 99157, 38183, 128745, 39923, 48745,
123943, 33317, 22505, 64783, 54309, 81937, 79225, 57396, 24492,
96514, 39701, 82672, 39802, 11636, 43050, 90102, 7458, 95314,
12094, 112447, 21920, 121899, 93214, 12538, 50689, 18613, 99060,
119417, 67631, 131893, 67081, 11254, 13363, 80987, 75939, 60997,
88553, 41553, 41577, 16988, 114289, 48573, 51201, 90027, 127627,
108380, 75207, 131320, 99726, 93940, 125394, 6237, 111653, 118549,
21610, 51271], device='cuda:3')
>>> batch.n_id[:128] - batch.input_id
tensor([22134, 25758, 3130, 503, 26772, 22502, 24106, 4691, 13029, 30615,
9092, 24793, 33208, 13380, 26970, 2872, 5473, 29412, 19563, 9507,
13054, 2351, 18709, 21499, 17131, 21058, 3046, 17300, 3815, 24224,
19179, 2613, 29136, 10939, 28670, 16837, 31762, 21305, 32697, 20217,
18921, 1038, 13032, 4491, 10401, 30792, 15389, 10154, 13365, 26263,
26533, 6230, 3939, 15937, 29293, 10579, 8606, 17705, 24714, 9258,
27626, 10513, 31202, 15157, 23071, 20260, 21864, 24655, 9624, 32220,
10082, 12252, 30936, 8455, 5742, 16190, 13589, 20314, 19657, 14388,
6252, 23986, 10029, 20502, 10048, 2954, 10823, 22355, 1912, 23691,
3069, 28077, 5577, 30394, 23179, 3197, 12707, 4743, 24637, 29794,
16885, 32983, 16740, 2842, 3396, 20071, 18836, 15289, 21970, 10458,
10460, 4326, 28551, 12217, 12834, 22336, 31949, 27062, 18668, 32830,
24804, 23354, 31331, 1613, 27892, 29587, 5500, 12861],
device='cuda:3')
>>> batch.batch[:128]
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127], device='cuda:3') |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 2 replies
-
Any help?
Beta Was this translation helpful? Give feedback.
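Hi @xiachenrui, in my understanding, the values in `input_id` are not node IDs but indices into the initial seed nodes (that is, positions within the set of input nodes given to the loader), so they are not expected to equal `batch.n_id[:batch_size]`. Let me give you an example (a slightly modified version of the code available in this tutorial).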