Skip to content

Commit 66264ff

Browse files
authored
Merge pull request #1 from BamLubi/bugfix-sign2static
fix: dynamic graph to static graph for SIGN model
2 parents b34af92 + 9b3f229 commit 66264ff

File tree

6 files changed

+1558
-23
lines changed

6 files changed

+1558
-23
lines changed

models/rank/sign/dygraph_model.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class DygraphModel():
4040
# define model
4141
def create_model(self, config):
4242
pred_edges = config.get('hyper_parameters.pred_edges', 1)
43-
dim = config.get('hyper_parameters.pred_edges', 8)
43+
dim = config.get('hyper_parameters.dim', 8)
4444
hidden_layer = config.get('hyper_parameters.hidden_layer', 64)
4545
l0_para = config.get('hyper_parameters.l0_para', [0.66, -0.1, 1.1])
4646
batch_size = config.get('runner.train_batch_size', 8)
@@ -66,7 +66,12 @@ def create_feeds(self, batch_data, config):
6666
labels.append(batch_data[4][i].numpy())
6767
graphs = pgl.Graph.batch(graphs).tensor()
6868
labels = paddle.to_tensor(labels, dtype='float32')
69-
return graphs, labels
69+
edges = np.array(graphs.edges, dtype="int32")
70+
node_feat = np.array(graphs.node_feat["node_attr"], dtype="int32")
71+
edge_feat = np.array(graphs.edge_feat["edge_attr"], dtype="int32")
72+
segment_ids = graphs.graph_node_id
73+
74+
return edges, node_feat, edge_feat, segment_ids, labels
7075

7176
# define loss function by predicts and label
7277
def create_loss(self, output, label, l0_penaty, l2_penaty, l0_weight,
@@ -100,9 +105,11 @@ def create_metrics(self):
100105

101106
# construct train forward phase
102107
def train_forward(self, dy_model, metrics_list, batch_data, config):
103-
graphs, labels = self.create_feeds(batch_data, config)
108+
edges, node_feat, edge_feat, segment_ids, labels = self.create_feeds(
109+
batch_data, config)
104110
# predict
105-
output, l0_penaty, l2_penaty = dy_model.forward(graphs, True)
111+
output, l0_penaty, l2_penaty = dy_model.forward(
112+
edges, node_feat, edge_feat, segment_ids, True)
106113
# get loss
107114
l0_weight = config.get("hyper_parameters.l0_weight", 0.001)
108115
l2_weight = config.get("hyper_parameters.l0_weight", 0.001)
@@ -122,9 +129,11 @@ def train_forward(self, dy_model, metrics_list, batch_data, config):
122129

123130
# construct infer forward phase
124131
def infer_forward(self, dy_model, metrics_list, batch_data, config):
125-
graphs, labels = self.create_feeds(batch_data, config)
132+
edges, node_feat, edge_feat, segment_ids, labels = self.create_feeds(
133+
batch_data, config)
126134
# predict
127-
output, _, _ = dy_model.forward(graphs, False)
135+
output, _, _ = dy_model.forward(edges, node_feat, edge_feat,
136+
segment_ids, False)
128137
# update metrics
129138
predictions = np.vstack(output)
130139
labels = np.vstack(labels)

0 commit comments

Comments
 (0)