@@ -40,7 +40,7 @@ class DygraphModel():
4040 # define model
4141 def create_model (self , config ):
4242 pred_edges = config .get ('hyper_parameters.pred_edges' , 1 )
43- dim = config .get ('hyper_parameters.pred_edges ' , 8 )
43+ dim = config .get ('hyper_parameters.dim ' , 8 )
4444 hidden_layer = config .get ('hyper_parameters.hidden_layer' , 64 )
4545 l0_para = config .get ('hyper_parameters.l0_para' , [0.66 , - 0.1 , 1.1 ])
4646 batch_size = config .get ('runner.train_batch_size' , 8 )
@@ -66,7 +66,12 @@ def create_feeds(self, batch_data, config):
6666 labels .append (batch_data [4 ][i ].numpy ())
6767 graphs = pgl .Graph .batch (graphs ).tensor ()
6868 labels = paddle .to_tensor (labels , dtype = 'float32' )
69- return graphs , labels
69+ edges = np .array (graphs .edges , dtype = "int32" )
70+ node_feat = np .array (graphs .node_feat ["node_attr" ], dtype = "int32" )
71+ edge_feat = np .array (graphs .edge_feat ["edge_attr" ], dtype = "int32" )
72+ segment_ids = graphs .graph_node_id
73+
74+ return edges , node_feat , edge_feat , segment_ids , labels
7075
7176 # define loss function by predicts and label
7277 def create_loss (self , output , label , l0_penaty , l2_penaty , l0_weight ,
@@ -100,9 +105,11 @@ def create_metrics(self):
100105
101106 # construct train forward phase
102107 def train_forward (self , dy_model , metrics_list , batch_data , config ):
103- graphs , labels = self .create_feeds (batch_data , config )
108+ edges , node_feat , edge_feat , segment_ids , labels = self .create_feeds (
109+ batch_data , config )
104110 # predict
105- output , l0_penaty , l2_penaty = dy_model .forward (graphs , True )
111+ output , l0_penaty , l2_penaty = dy_model .forward (
112+ edges , node_feat , edge_feat , segment_ids , True )
106113 # get loss
107114 l0_weight = config .get ("hyper_parameters.l0_weight" , 0.001 )
108115 l2_weight = config .get ("hyper_parameters.l0_weight" , 0.001 )
@@ -122,9 +129,11 @@ def train_forward(self, dy_model, metrics_list, batch_data, config):
122129
123130 # construct infer forward phase
124131 def infer_forward (self , dy_model , metrics_list , batch_data , config ):
125- graphs , labels = self .create_feeds (batch_data , config )
132+ edges , node_feat , edge_feat , segment_ids , labels = self .create_feeds (
133+ batch_data , config )
126134 # predict
127- output , _ , _ = dy_model .forward (graphs , False )
135+ output , _ , _ = dy_model .forward (edges , node_feat , edge_feat ,
136+ segment_ids , False )
128137 # update metrics
129138 predictions = np .vstack (output )
130139 labels = np .vstack (labels )
0 commit comments