1313 * limitations under the License.
1414 */
1515#include " flex/engines/graph_db/app/builtin/pagerank.h"
16+ #include " flex/engines/graph_db/runtime/common/graph_interface.h"
17+ #include " flex/engines/graph_db/runtime/common/rt_any.h"
1618
1719namespace gs {
1820
19- results::CollectiveResults PageRank::Query (const GraphDBSession& sess,
20- std::string vertex_label,
21- std::string edge_label,
22- double damping_factor,
23- int max_iterations, double epsilon) {
21+ void write_result (
22+ const ReadTransaction& txn, results::CollectiveResults& results,
23+ const std::vector<std::tuple<label_t , vid_t , double >>& pagerank,
24+ int32_t result_limit) {
25+ runtime::GraphReadInterface graph (txn);
26+
27+ for (int32_t j = 0 ; j < std::min ((int32_t ) pagerank.size (), result_limit);
28+ ++j) {
29+ auto vertex_label = std::get<0 >(pagerank[j]);
30+ auto vertex_label_name = txn.schema ().get_vertex_label_name (vertex_label);
31+ auto vid = std::get<1 >(pagerank[j]);
32+ runtime::RTAny any (txn.GetVertexId (vertex_label, vid));
33+ auto result = results.add_results ();
34+ auto first_col = result->mutable_record ()->add_columns ();
35+ first_col->mutable_name_or_id ()->set_id (0 );
36+ first_col->mutable_entry ()->mutable_element ()->mutable_object ()->set_str (
37+ vertex_label_name);
38+
39+ auto oid_col = result->mutable_record ()->add_columns ();
40+ any.sink (graph, 1 , oid_col);
41+
42+ auto pagerank_col = result->mutable_record ()->add_columns ();
43+ pagerank_col->mutable_name_or_id ()->set_id (2 );
44+ pagerank_col->mutable_entry ()->mutable_element ()->mutable_object ()->set_f64 (
45+ std::get<2 >(pagerank[j]));
46+ }
47+ }
48+
49+ results::CollectiveResults PageRank::Query (
50+ const GraphDBSession& sess, std::string src_vertex_label,
51+ std::string dst_vertex_label, std::string edge_label, double damping_factor,
52+ int32_t max_iterations, double epsilon, int32_t result_limit) {
2453 auto txn = sess.GetReadTransaction ();
2554
26- if (!sess.schema ().has_vertex_label (vertex_label)) {
27- LOG (ERROR) << " The requested vertex label doesn't exits." ;
55+ if (!sess.schema ().has_vertex_label (src_vertex_label)) {
56+ LOG (ERROR) << " The requested src vertex label doesn't exits." ;
57+ return {};
58+ }
59+ if (!sess.schema ().has_vertex_label (dst_vertex_label)) {
60+ LOG (ERROR) << " The requested dst vertex label doesn't exits." ;
2861 return {};
2962 }
30- if (!sess.schema ().has_edge_label (vertex_label, vertex_label, edge_label)) {
63+ if (!sess.schema ().has_edge_label (src_vertex_label, dst_vertex_label,
64+ edge_label)) {
3165 LOG (ERROR) << " The requested edge label doesn't exits." ;
3266 return {};
3367 }
@@ -44,59 +78,117 @@ results::CollectiveResults PageRank::Query(const GraphDBSession& sess,
4478 return {};
4579 }
4680
47- auto vertex_label_id = sess.schema ().get_vertex_label_id (vertex_label);
81+ auto src_vertex_label_id =
82+ sess.schema ().get_vertex_label_id (src_vertex_label);
83+ auto dst_vertex_label_id =
84+ sess.schema ().get_vertex_label_id (dst_vertex_label);
4885 auto edge_label_id = sess.schema ().get_edge_label_id (edge_label);
4986
50- auto num_vertices = txn.GetVertexNum (vertex_label_id);
51-
52- std::unordered_map<vid_t , double > pagerank;
53- std::unordered_map<vid_t , double > new_pagerank;
87+ auto num_src_vertices = txn.GetVertexNum (src_vertex_label_id);
88+ auto num_dst_vertices = txn.GetVertexNum (dst_vertex_label_id);
89+ auto num_vertices = src_vertex_label_id == dst_vertex_label_id
90+ ? num_src_vertices
91+ : num_src_vertices + num_dst_vertices;
92+
93+ std::vector<std::vector<double >> pagerank;
94+ std::vector<std::vector<double >> new_pagerank;
95+ std::vector<std::vector<int32_t >> outdegree;
96+
97+ bool dst_to_src = src_vertex_label_id != dst_vertex_label_id &&
98+ txn.schema ().exist (dst_vertex_label_id, src_vertex_label_id,
99+ edge_label_id);
100+
101+ pagerank.emplace_back (std::vector<double >(num_src_vertices, 0.0 ));
102+ new_pagerank.emplace_back (std::vector<double >(num_src_vertices, 0.0 ));
103+ outdegree.emplace_back (std::vector<int32_t >(num_src_vertices, 0 ));
104+ if (dst_to_src) {
105+ pagerank.emplace_back (std::vector<double >(num_dst_vertices, 0.0 ));
106+ new_pagerank.emplace_back (std::vector<double >(num_dst_vertices, 0.0 ));
107+ outdegree.emplace_back (std::vector<int32_t >(num_dst_vertices, 0 ));
108+ }
54109
55- auto vertex_iter = txn.GetVertexIterator (vertex_label_id );
110+ auto src_vertex_iter = txn.GetVertexIterator (src_vertex_label_id );
56111
57- while (vertex_iter.IsValid ()) {
58- vid_t vid = vertex_iter.GetIndex ();
59- pagerank[vid] = 1.0 / num_vertices;
60- new_pagerank[vid] = 0.0 ;
61- vertex_iter.Next ();
112+ while (src_vertex_iter.IsValid ()) {
113+ vid_t vid = src_vertex_iter.GetIndex ();
114+ pagerank[0 ][vid] = 1.0 / num_vertices;
115+ new_pagerank[0 ][vid] = 0.0 ;
116+ src_vertex_iter.Next ();
117+ outdegree[0 ][vid] = txn.GetOutDegree (src_vertex_label_id, vid,
118+ dst_vertex_label_id, edge_label_id);
119+ }
120+ if (dst_to_src) {
121+ auto dst_vertex_iter = txn.GetVertexIterator (dst_vertex_label_id);
122+ while (dst_vertex_iter.IsValid ()) {
123+ vid_t vid = dst_vertex_iter.GetIndex ();
124+ pagerank[1 ][vid] = 1.0 / num_vertices;
125+ new_pagerank[1 ][vid] = 0.0 ;
126+ dst_vertex_iter.Next ();
127+ outdegree[1 ][vid] = txn.GetOutDegree (
128+ dst_vertex_label_id, src_vertex_label_id, vid, edge_label_id);
129+ }
62130 }
63-
64- std::unordered_map<vid_t , double > outdegree;
65131
66132 for (int iter = 0 ; iter < max_iterations; ++iter) {
67- for (auto & kv : new_pagerank) {
68- kv.second = 0.0 ;
133+ new_pagerank[0 ].assign (num_src_vertices, 0.0 );
134+ if (dst_to_src) {
135+ new_pagerank[1 ].assign (num_dst_vertices, 0.0 );
69136 }
70137
71- auto vertex_iter = txn.GetVertexIterator (vertex_label_id );
72- while (vertex_iter .IsValid ()) {
73- vid_t v = vertex_iter .GetIndex ();
138+ auto src_vertex_iter = txn.GetVertexIterator (src_vertex_label_id );
139+ while (src_vertex_iter .IsValid ()) {
140+ vid_t v = src_vertex_iter .GetIndex ();
74141
75142 double sum = 0.0 ;
76- auto edges = txn.GetInEdgeIterator (vertex_label_id, v, vertex_label_id,
77- edge_label_id);
78- while (edges.IsValid ()) {
79- auto neighbor = edges.GetNeighbor ();
80- if (outdegree[neighbor] == 0 ) {
81- auto out_edges = txn.GetOutEdgeIterator (
82- vertex_label_id, neighbor, vertex_label_id, edge_label_id);
83- while (out_edges.IsValid ()) {
84- outdegree[neighbor]++;
85- out_edges.Next ();
86- }
143+ {
144+ auto edges = txn.GetInEdgeIterator (dst_vertex_label_id, v,
145+ src_vertex_label_id, edge_label_id);
146+ while (edges.IsValid ()) {
147+ auto neighbor = edges.GetNeighbor ();
148+ sum += pagerank[0 ][neighbor] / outdegree[0 ][neighbor];
149+ edges.Next ();
87150 }
88- sum += pagerank[neighbor] / outdegree[neighbor];
89- edges.Next ();
90151 }
91152
92- new_pagerank[v] =
153+ new_pagerank[0 ][ v] =
93154 damping_factor * sum + (1.0 - damping_factor) / num_vertices;
94- vertex_iter.Next ();
155+ src_vertex_iter.Next ();
156+ }
157+
158+ if (dst_to_src) {
159+ auto dst_vertex_iter = txn.GetVertexIterator (dst_vertex_label_id);
160+ while (dst_vertex_iter.IsValid ()) {
161+ vid_t v = dst_vertex_iter.GetIndex ();
162+
163+ double sum = 0.0 ;
164+ {
165+ auto edges = txn.GetInEdgeIterator (
166+ src_vertex_label_id, v, dst_vertex_label_id, edge_label_id);
167+ while (edges.IsValid ()) {
168+ LOG (INFO) << " got edge, from " << edges.GetNeighbor () << " to " << v
169+ << " label: " << std::to_string (src_vertex_label_id)
170+ << " " << std::to_string (dst_vertex_label_id) << " "
171+ << std::to_string (edge_label_id);
172+ auto neighbor = edges.GetNeighbor ();
173+ sum += pagerank[1 ][neighbor] / outdegree[1 ][neighbor];
174+ edges.Next ();
175+ }
176+ }
177+
178+ new_pagerank[1 ][v] =
179+ damping_factor * sum + (1.0 - damping_factor) / num_vertices;
180+ dst_vertex_iter.Next ();
181+ }
95182 }
96183
97184 double diff = 0.0 ;
98- for (const auto & kv : pagerank) {
99- diff += std::abs (new_pagerank[kv.first ] - kv.second );
185+ for (size_t j = 0 ; j < new_pagerank[0 ].size (); ++j) {
186+ diff += std::abs (new_pagerank[0 ][j] - pagerank[0 ][j]);
187+ }
188+ if (dst_to_src) {
189+ for (size_t j = 0 ; j < new_pagerank[1 ].size (); ++j) {
190+ diff += std::abs (new_pagerank[1 ][j] - pagerank[1 ][j]);
191+ }
100192 }
101193
102194 if (diff < epsilon) {
@@ -108,28 +200,23 @@ results::CollectiveResults PageRank::Query(const GraphDBSession& sess,
108200
109201 results::CollectiveResults results;
110202
111- for (auto kv : pagerank) {
112- int64_t oid_ = txn.GetVertexId (vertex_label_id, kv.first ).AsInt64 ();
113- auto result = results.add_results ();
114- result->mutable_record ()
115- ->add_columns ()
116- ->mutable_entry ()
117- ->mutable_element ()
118- ->mutable_object ()
119- ->set_str (vertex_label);
120- result->mutable_record ()
121- ->add_columns ()
122- ->mutable_entry ()
123- ->mutable_element ()
124- ->mutable_object ()
125- ->set_i64 (oid_);
126- result->mutable_record ()
127- ->add_columns ()
128- ->mutable_entry ()
129- ->mutable_element ()
130- ->mutable_object ()
131- ->set_f64 (kv.second );
203+ std::vector<std::tuple<label_t , vid_t , double >> final_pagerank (num_vertices);
204+ for (size_t i = 0 ; i < pagerank[0 ].size (); ++i) {
205+ final_pagerank[i] = std::make_tuple (src_vertex_label_id, i, pagerank[0 ][i]);
132206 }
207+ if (dst_to_src) {
208+ for (size_t i = 0 ; i < pagerank[1 ].size (); ++i) {
209+ final_pagerank[i + num_src_vertices] =
210+ std::make_tuple (dst_vertex_label_id, i, pagerank[1 ][i]);
211+ }
212+ }
213+ std::sort (final_pagerank.begin (), final_pagerank.end (),
214+ [](const std::tuple<label_t , vid_t , double >& a,
215+ const std::tuple<label_t , vid_t , double >& b) {
216+ return std::get<2 >(a) > std::get<2 >(b);
217+ });
218+
219+ write_result (txn, results, final_pagerank, result_limit);
133220
134221 txn.Commit ();
135222 return results;
0 commit comments