Skip to content

Commit fa07f86

Browse files
authored
feat(learning): feature_store & graph_store V1 (#4237)
<!-- Thanks for your contribution! please review https://github.com/alibaba/GraphScope/blob/main/CONTRIBUTING.md before opening an issue. --> ## What do these changes do? Step 1: Implement GraphScope-based PyG Remote Backend and complete the end-to-end integration of GraphScope and PyG. (Finished) Step 2: Get data from the Server through PyG Remote Backend and support sampling on the Client side. (Finished) ## Related issue number <!-- Are there any issues opened that will be resolved by merging this change? --> PyG Remote Backend Based on GraphScope #3739
1 parent 49b2643 commit fa07f86

File tree

5 files changed

+534
-1
lines changed

5 files changed

+534
-1
lines changed

python/graphscope/client/session.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,8 +1331,11 @@ def graphlearn_torch(
13311331
num_clients=1,
13321332
manifest_path=None,
13331333
client_folder_path="./",
1334+
return_pyg_remote_backend=False,
13341335
):
13351336
from graphscope.learning.gl_torch_graph import GLTorchGraph
1337+
from graphscope.learning.gs_feature_store import GsFeatureStore
1338+
from graphscope.learning.gs_graph_store import GsGraphStore
13361339
from graphscope.learning.utils import fill_params_in_yaml
13371340
from graphscope.learning.utils import read_folder_files_content
13381341

@@ -1380,6 +1383,12 @@ def graphlearn_torch(
13801383
g = GLTorchGraph(endpoints)
13811384
self._learning_instance_dict[graph.vineyard_id] = g
13821385
graph._attach_learning_instance(g)
1386+
1387+
if return_pyg_remote_backend:
1388+
feature_store = GsFeatureStore(config)
1389+
graph_store = GsGraphStore(config)
1390+
return g, feature_store, graph_store
1391+
13831392
return g
13841393

13851394
def nx(self):
@@ -1682,6 +1691,7 @@ def graphlearn_torch(
16821691
num_clients=1,
16831692
manifest_path=None,
16841693
client_folder_path="./",
1694+
return_pyg_remote_backend=False,
16851695
):
16861696
assert graph is not None, "graph cannot be None"
16871697
assert (
@@ -1699,4 +1709,5 @@ def graphlearn_torch(
16991709
num_clients,
17001710
manifest_path,
17011711
client_folder_path,
1712+
return_pyg_remote_backend,
17021713
) # pylint: disable=protected-access
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from ogb.nodeproppred import Evaluator
4+
from torch_geometric.data.feature_store import TensorAttr
5+
from torch_geometric.loader import NeighborLoader
6+
from torch_geometric.nn import GraphSAGE
7+
from tqdm import tqdm
8+
9+
import graphscope as gs
10+
import graphscope.learning.graphlearn_torch as glt
11+
from graphscope.dataset import load_ogbn_arxiv
12+
13+
NUM_EPOCHS = 10
14+
BATCH_SIZE = 4096
15+
NUM_SERVERS = 1
16+
NUM_NEIGHBORS = [2, 2, 2]
17+
18+
print("Batch size:", BATCH_SIZE)
19+
print("Number of epochs:", NUM_EPOCHS)
20+
print("Number of neighbors:", NUM_NEIGHBORS)
21+
22+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23+
print("Using device:", device)
24+
25+
gs.set_option(show_log=True)
26+
27+
# load the ogbn_arxiv graph.
28+
sess = gs.session(cluster_type="hosts", num_workers=NUM_SERVERS)
29+
g = load_ogbn_arxiv(sess=sess)
30+
31+
print("-- Initializing store ...")
32+
glt_graph, feature_store, graph_store = gs.graphlearn_torch(
33+
g,
34+
edges=[
35+
("paper", "citation", "paper"),
36+
],
37+
node_features={
38+
"paper": [f"feat_{i}" for i in range(128)],
39+
},
40+
node_labels={
41+
"paper": "label",
42+
},
43+
edge_dir="out",
44+
random_node_split={
45+
"num_val": 0.1,
46+
"num_test": 0.1,
47+
},
48+
return_pyg_remote_backend=True,
49+
)
50+
51+
print("-- Initializing client ...")
52+
glt.distributed.init_client(
53+
num_servers=1,
54+
num_clients=1,
55+
client_rank=0,
56+
master_addr=glt_graph.master_addr,
57+
master_port=glt_graph.server_client_master_port,
58+
num_rpc_threads=4,
59+
is_dynamic=True,
60+
)
61+
62+
63+
print("-- Initializing loader ...")
64+
# get train & test mask
65+
num_nodes = feature_store.get_tensor_size(TensorAttr(group_name="paper"))[0]
66+
print("Node num:", num_nodes)
67+
shuffle_id = torch.randperm(num_nodes)
68+
train_indices = shuffle_id[: int(0.8 * num_nodes)]
69+
test_indices = shuffle_id[int(0.2 * num_nodes) :]
70+
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
71+
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
72+
train_mask[train_indices] = True
73+
test_mask[test_indices] = True
74+
75+
train_loader = NeighborLoader(
76+
data=(feature_store, graph_store),
77+
batch_size=BATCH_SIZE,
78+
num_neighbors=NUM_NEIGHBORS,
79+
shuffle=False,
80+
input_nodes=("paper", train_mask),
81+
)
82+
83+
test_loader = NeighborLoader(
84+
data=(feature_store, graph_store),
85+
batch_size=BATCH_SIZE,
86+
num_neighbors=NUM_NEIGHBORS,
87+
shuffle=False,
88+
input_nodes=("paper", test_mask),
89+
)
90+
91+
model = GraphSAGE(
92+
in_channels=128,
93+
hidden_channels=256,
94+
num_layers=3,
95+
out_channels=47,
96+
).to(device)
97+
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
98+
99+
100+
@torch.no_grad()
101+
def test(model, test_loader, dataset_name):
102+
evaluator = Evaluator(name=dataset_name)
103+
model.eval()
104+
xs = []
105+
y_true = []
106+
for i, batch in enumerate(test_loader):
107+
if i == 0:
108+
device = batch["paper"].x.device
109+
batch["paper"].x = batch["paper"].x.to(torch.float32) # TODO
110+
x = model(batch["paper"].x, batch[("paper", "citation", "paper")].edge_index)[
111+
: batch["paper"].batch_size
112+
]
113+
xs.append(x.cpu())
114+
y_true.append(batch["paper"].label[: batch["paper"].batch_size].cpu())
115+
del batch
116+
117+
xs = [t.to(device) for t in xs]
118+
y_true = [t.to(device) for t in y_true]
119+
y_pred = torch.cat(xs, dim=0).argmax(dim=-1, keepdim=True)
120+
y_true = torch.cat(y_true, dim=0).unsqueeze(-1)
121+
test_acc = evaluator.eval(
122+
{
123+
"y_true": y_true,
124+
"y_pred": y_pred,
125+
}
126+
)["acc"]
127+
return test_acc
128+
129+
130+
dataset_name = "ogbn-arxiv"
131+
for epoch in range(NUM_EPOCHS):
132+
model.train()
133+
with tqdm(
134+
total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", unit="batch"
135+
) as pbar:
136+
for batch in train_loader:
137+
optimizer.zero_grad()
138+
batch["paper"].x = batch["paper"].x.to(torch.float32) # TODO
139+
out = model(
140+
batch["paper"].x, batch[("paper", "citation", "paper")].edge_index
141+
)[: batch["paper"].batch_size].log_softmax(dim=-1)
142+
label = batch["paper"].label[: batch["paper"].batch_size].long()
143+
loss = F.nll_loss(out, label)
144+
loss.backward()
145+
optimizer.step()
146+
pbar.set_postfix({"Loss": f"{loss:.4f}"})
147+
pbar.update(1)
148+
149+
# Test accuracy.
150+
if epoch % 2 == 0:
151+
test_acc = test(model, test_loader, dataset_name)
152+
print(f"-- Test Accuracy: {test_acc:.4f}", flush=True)
153+
154+
print("-- Shutdowning ...")
155+
glt.distributed.shutdown_client()
156+
157+
print("-- Exited ...")

0 commit comments

Comments
 (0)