Skip to content

Commit 5f02377

Browse files
authored
refactor(interactive): Support running pagerank on edge triplet subgraph (#4537)
Preivously, builtin pagerank suppose that the src and dst vertex of the edge label is of the same label. In this PR, we remove this constraint.
1 parent aae2c37 commit 5f02377

File tree

6 files changed

+186
-70
lines changed

6 files changed

+186
-70
lines changed

flex/engines/graph_db/app/builtin/pagerank.cc

Lines changed: 151 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,55 @@
1313
* limitations under the License.
1414
*/
1515
#include "flex/engines/graph_db/app/builtin/pagerank.h"
16+
#include "flex/engines/graph_db/runtime/common/graph_interface.h"
17+
#include "flex/engines/graph_db/runtime/common/rt_any.h"
1618

1719
namespace gs {
1820

19-
results::CollectiveResults PageRank::Query(const GraphDBSession& sess,
20-
std::string vertex_label,
21-
std::string edge_label,
22-
double damping_factor,
23-
int max_iterations, double epsilon) {
21+
void write_result(
22+
const ReadTransaction& txn, results::CollectiveResults& results,
23+
const std::vector<std::tuple<label_t, vid_t, double>>& pagerank,
24+
int32_t result_limit) {
25+
runtime::GraphReadInterface graph(txn);
26+
27+
for (int32_t j = 0; j < std::min((int32_t) pagerank.size(), result_limit);
28+
++j) {
29+
auto vertex_label = std::get<0>(pagerank[j]);
30+
auto vertex_label_name = txn.schema().get_vertex_label_name(vertex_label);
31+
auto vid = std::get<1>(pagerank[j]);
32+
runtime::RTAny any(txn.GetVertexId(vertex_label, vid));
33+
auto result = results.add_results();
34+
auto first_col = result->mutable_record()->add_columns();
35+
first_col->mutable_name_or_id()->set_id(0);
36+
first_col->mutable_entry()->mutable_element()->mutable_object()->set_str(
37+
vertex_label_name);
38+
39+
auto oid_col = result->mutable_record()->add_columns();
40+
any.sink(graph, 1, oid_col);
41+
42+
auto pagerank_col = result->mutable_record()->add_columns();
43+
pagerank_col->mutable_name_or_id()->set_id(2);
44+
pagerank_col->mutable_entry()->mutable_element()->mutable_object()->set_f64(
45+
std::get<2>(pagerank[j]));
46+
}
47+
}
48+
49+
results::CollectiveResults PageRank::Query(
50+
const GraphDBSession& sess, std::string src_vertex_label,
51+
std::string dst_vertex_label, std::string edge_label, double damping_factor,
52+
int32_t max_iterations, double epsilon, int32_t result_limit) {
2453
auto txn = sess.GetReadTransaction();
2554

26-
if (!sess.schema().has_vertex_label(vertex_label)) {
27-
LOG(ERROR) << "The requested vertex label doesn't exits.";
55+
if (!sess.schema().has_vertex_label(src_vertex_label)) {
56+
LOG(ERROR) << "The requested src vertex label doesn't exits.";
57+
return {};
58+
}
59+
if (!sess.schema().has_vertex_label(dst_vertex_label)) {
60+
LOG(ERROR) << "The requested dst vertex label doesn't exits.";
2861
return {};
2962
}
30-
if (!sess.schema().has_edge_label(vertex_label, vertex_label, edge_label)) {
63+
if (!sess.schema().has_edge_label(src_vertex_label, dst_vertex_label,
64+
edge_label)) {
3165
LOG(ERROR) << "The requested edge label doesn't exits.";
3266
return {};
3367
}
@@ -44,59 +78,117 @@ results::CollectiveResults PageRank::Query(const GraphDBSession& sess,
4478
return {};
4579
}
4680

47-
auto vertex_label_id = sess.schema().get_vertex_label_id(vertex_label);
81+
auto src_vertex_label_id =
82+
sess.schema().get_vertex_label_id(src_vertex_label);
83+
auto dst_vertex_label_id =
84+
sess.schema().get_vertex_label_id(dst_vertex_label);
4885
auto edge_label_id = sess.schema().get_edge_label_id(edge_label);
4986

50-
auto num_vertices = txn.GetVertexNum(vertex_label_id);
51-
52-
std::unordered_map<vid_t, double> pagerank;
53-
std::unordered_map<vid_t, double> new_pagerank;
87+
auto num_src_vertices = txn.GetVertexNum(src_vertex_label_id);
88+
auto num_dst_vertices = txn.GetVertexNum(dst_vertex_label_id);
89+
auto num_vertices = src_vertex_label_id == dst_vertex_label_id
90+
? num_src_vertices
91+
: num_src_vertices + num_dst_vertices;
92+
93+
std::vector<std::vector<double>> pagerank;
94+
std::vector<std::vector<double>> new_pagerank;
95+
std::vector<std::vector<int32_t>> outdegree;
96+
97+
bool dst_to_src = src_vertex_label_id != dst_vertex_label_id &&
98+
txn.schema().exist(dst_vertex_label_id, src_vertex_label_id,
99+
edge_label_id);
100+
101+
pagerank.emplace_back(std::vector<double>(num_src_vertices, 0.0));
102+
new_pagerank.emplace_back(std::vector<double>(num_src_vertices, 0.0));
103+
outdegree.emplace_back(std::vector<int32_t>(num_src_vertices, 0));
104+
if (dst_to_src) {
105+
pagerank.emplace_back(std::vector<double>(num_dst_vertices, 0.0));
106+
new_pagerank.emplace_back(std::vector<double>(num_dst_vertices, 0.0));
107+
outdegree.emplace_back(std::vector<int32_t>(num_dst_vertices, 0));
108+
}
54109

55-
auto vertex_iter = txn.GetVertexIterator(vertex_label_id);
110+
auto src_vertex_iter = txn.GetVertexIterator(src_vertex_label_id);
56111

57-
while (vertex_iter.IsValid()) {
58-
vid_t vid = vertex_iter.GetIndex();
59-
pagerank[vid] = 1.0 / num_vertices;
60-
new_pagerank[vid] = 0.0;
61-
vertex_iter.Next();
112+
while (src_vertex_iter.IsValid()) {
113+
vid_t vid = src_vertex_iter.GetIndex();
114+
pagerank[0][vid] = 1.0 / num_vertices;
115+
new_pagerank[0][vid] = 0.0;
116+
src_vertex_iter.Next();
117+
outdegree[0][vid] = txn.GetOutDegree(src_vertex_label_id, vid,
118+
dst_vertex_label_id, edge_label_id);
119+
}
120+
if (dst_to_src) {
121+
auto dst_vertex_iter = txn.GetVertexIterator(dst_vertex_label_id);
122+
while (dst_vertex_iter.IsValid()) {
123+
vid_t vid = dst_vertex_iter.GetIndex();
124+
pagerank[1][vid] = 1.0 / num_vertices;
125+
new_pagerank[1][vid] = 0.0;
126+
dst_vertex_iter.Next();
127+
outdegree[1][vid] = txn.GetOutDegree(
128+
dst_vertex_label_id, src_vertex_label_id, vid, edge_label_id);
129+
}
62130
}
63-
64-
std::unordered_map<vid_t, double> outdegree;
65131

66132
for (int iter = 0; iter < max_iterations; ++iter) {
67-
for (auto& kv : new_pagerank) {
68-
kv.second = 0.0;
133+
new_pagerank[0].assign(num_src_vertices, 0.0);
134+
if (dst_to_src) {
135+
new_pagerank[1].assign(num_dst_vertices, 0.0);
69136
}
70137

71-
auto vertex_iter = txn.GetVertexIterator(vertex_label_id);
72-
while (vertex_iter.IsValid()) {
73-
vid_t v = vertex_iter.GetIndex();
138+
auto src_vertex_iter = txn.GetVertexIterator(src_vertex_label_id);
139+
while (src_vertex_iter.IsValid()) {
140+
vid_t v = src_vertex_iter.GetIndex();
74141

75142
double sum = 0.0;
76-
auto edges = txn.GetInEdgeIterator(vertex_label_id, v, vertex_label_id,
77-
edge_label_id);
78-
while (edges.IsValid()) {
79-
auto neighbor = edges.GetNeighbor();
80-
if (outdegree[neighbor] == 0) {
81-
auto out_edges = txn.GetOutEdgeIterator(
82-
vertex_label_id, neighbor, vertex_label_id, edge_label_id);
83-
while (out_edges.IsValid()) {
84-
outdegree[neighbor]++;
85-
out_edges.Next();
86-
}
143+
{
144+
auto edges = txn.GetInEdgeIterator(dst_vertex_label_id, v,
145+
src_vertex_label_id, edge_label_id);
146+
while (edges.IsValid()) {
147+
auto neighbor = edges.GetNeighbor();
148+
sum += pagerank[0][neighbor] / outdegree[0][neighbor];
149+
edges.Next();
87150
}
88-
sum += pagerank[neighbor] / outdegree[neighbor];
89-
edges.Next();
90151
}
91152

92-
new_pagerank[v] =
153+
new_pagerank[0][v] =
93154
damping_factor * sum + (1.0 - damping_factor) / num_vertices;
94-
vertex_iter.Next();
155+
src_vertex_iter.Next();
156+
}
157+
158+
if (dst_to_src) {
159+
auto dst_vertex_iter = txn.GetVertexIterator(dst_vertex_label_id);
160+
while (dst_vertex_iter.IsValid()) {
161+
vid_t v = dst_vertex_iter.GetIndex();
162+
163+
double sum = 0.0;
164+
{
165+
auto edges = txn.GetInEdgeIterator(
166+
src_vertex_label_id, v, dst_vertex_label_id, edge_label_id);
167+
while (edges.IsValid()) {
168+
LOG(INFO) << "got edge, from " << edges.GetNeighbor() << " to " << v
169+
<< " label: " << std::to_string(src_vertex_label_id)
170+
<< " " << std::to_string(dst_vertex_label_id) << " "
171+
<< std::to_string(edge_label_id);
172+
auto neighbor = edges.GetNeighbor();
173+
sum += pagerank[1][neighbor] / outdegree[1][neighbor];
174+
edges.Next();
175+
}
176+
}
177+
178+
new_pagerank[1][v] =
179+
damping_factor * sum + (1.0 - damping_factor) / num_vertices;
180+
dst_vertex_iter.Next();
181+
}
95182
}
96183

97184
double diff = 0.0;
98-
for (const auto& kv : pagerank) {
99-
diff += std::abs(new_pagerank[kv.first] - kv.second);
185+
for (size_t j = 0; j < new_pagerank[0].size(); ++j) {
186+
diff += std::abs(new_pagerank[0][j] - pagerank[0][j]);
187+
}
188+
if (dst_to_src) {
189+
for (size_t j = 0; j < new_pagerank[1].size(); ++j) {
190+
diff += std::abs(new_pagerank[1][j] - pagerank[1][j]);
191+
}
100192
}
101193

102194
if (diff < epsilon) {
@@ -108,28 +200,23 @@ results::CollectiveResults PageRank::Query(const GraphDBSession& sess,
108200

109201
results::CollectiveResults results;
110202

111-
for (auto kv : pagerank) {
112-
int64_t oid_ = txn.GetVertexId(vertex_label_id, kv.first).AsInt64();
113-
auto result = results.add_results();
114-
result->mutable_record()
115-
->add_columns()
116-
->mutable_entry()
117-
->mutable_element()
118-
->mutable_object()
119-
->set_str(vertex_label);
120-
result->mutable_record()
121-
->add_columns()
122-
->mutable_entry()
123-
->mutable_element()
124-
->mutable_object()
125-
->set_i64(oid_);
126-
result->mutable_record()
127-
->add_columns()
128-
->mutable_entry()
129-
->mutable_element()
130-
->mutable_object()
131-
->set_f64(kv.second);
203+
std::vector<std::tuple<label_t, vid_t, double>> final_pagerank(num_vertices);
204+
for (size_t i = 0; i < pagerank[0].size(); ++i) {
205+
final_pagerank[i] = std::make_tuple(src_vertex_label_id, i, pagerank[0][i]);
132206
}
207+
if (dst_to_src) {
208+
for (size_t i = 0; i < pagerank[1].size(); ++i) {
209+
final_pagerank[i + num_src_vertices] =
210+
std::make_tuple(dst_vertex_label_id, i, pagerank[1][i]);
211+
}
212+
}
213+
std::sort(final_pagerank.begin(), final_pagerank.end(),
214+
[](const std::tuple<label_t, vid_t, double>& a,
215+
const std::tuple<label_t, vid_t, double>& b) {
216+
return std::get<2>(a) > std::get<2>(b);
217+
});
218+
219+
write_result(txn, results, final_pagerank, result_limit);
133220

134221
txn.Commit();
135222
return results;

flex/engines/graph_db/app/builtin/pagerank.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,22 @@
1515

1616
#ifndef ENGINES_GRAPH_DB_APP_BUILDIN_PAGERANK_H_
1717
#define ENGINES_GRAPH_DB_APP_BUILDIN_PAGERANK_H_
18+
1819
#include "flex/engines/graph_db/database/graph_db_session.h"
1920
#include "flex/engines/hqps_db/app/interactive_app_base.h"
2021

2122
namespace gs {
22-
class PageRank
23-
: public CypherReadAppBase<std::string, std::string, double, int, double> {
23+
class PageRank : public CypherReadAppBase<std::string, std::string, std::string,
24+
double, int32_t, double, int32_t> {
2425
public:
2526
PageRank() {}
2627
results::CollectiveResults Query(const GraphDBSession& sess,
27-
std::string vertex_label,
28+
std::string src_vertex_label,
29+
std::string dst_vertex_label,
2830
std::string edge_label,
29-
double damping_factor, int max_iterations,
30-
double epsilon);
31+
double damping_factor,
32+
int32_t max_iterations, double epsilon,
33+
int32_t result_limit) override;
3134
};
3235

3336
class PageRankFactory : public AppFactoryBase {

flex/engines/graph_db/database/read_transaction.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ ReadTransaction::edge_iterator ReadTransaction::GetInEdgeIterator(
137137
graph_.get_incoming_edges(label, u, neighbor_label, edge_label)};
138138
}
139139

140+
size_t ReadTransaction::GetOutDegree(label_t label, vid_t u,
141+
label_t neighbor_label,
142+
label_t edge_label) const {
143+
return graph_.get_outgoing_edges(label, u, neighbor_label, edge_label)
144+
->size();
145+
}
146+
147+
size_t ReadTransaction::GetInDegree(label_t label, vid_t u,
148+
label_t neighbor_label,
149+
label_t edge_label) const {
150+
return graph_.get_incoming_edges(label, u, neighbor_label, edge_label)
151+
->size();
152+
}
153+
140154
void ReadTransaction::release() {
141155
if (timestamp_ != std::numeric_limits<timestamp_t>::max()) {
142156
vm_.release_read_timestamp();

flex/engines/graph_db/database/read_transaction.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,12 @@ class ReadTransaction {
491491
label_t neighbor_label,
492492
label_t edge_label) const;
493493

494+
size_t GetOutDegree(label_t label, vid_t u, label_t neighbor_label,
495+
label_t edge_label) const;
496+
497+
size_t GetInDegree(label_t label, vid_t u, label_t neighbor_label,
498+
label_t edge_label) const;
499+
494500
template <typename EDATA_T>
495501
AdjListView<EDATA_T> GetOutgoingEdges(label_t v_label, vid_t v,
496502
label_t neighbor_label,

flex/interactive/sdk/python/gs_interactive/tests/test_robustness.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,12 @@ def test_builtin_procedure(interactive_session, neo4j_session, create_modern_gra
260260
create_modern_graph,
261261
"pagerank",
262262
'"person"',
263+
'"person"',
263264
'"knows"',
264265
"0.85",
265266
"100",
266267
"0.000001",
268+
"10",
267269
)
268270

269271
call_procedure(

flex/storages/metadata/graph_meta_store.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,15 @@ const std::vector<PluginMeta>& get_builtin_plugin_metas() {
8080
pagerank.type = "cypher";
8181
pagerank.creation_time = GetCurrentTimeStamp();
8282
pagerank.update_time = GetCurrentTimeStamp();
83-
pagerank.params.push_back({"vertex_label", PropertyType::kString, true});
83+
pagerank.params.push_back(
84+
{"src_vertex_label", PropertyType::kString, true});
85+
pagerank.params.push_back(
86+
{"dst_vertex_label", PropertyType::kString, true});
8487
pagerank.params.push_back({"edge_label", PropertyType::kString, true});
8588
pagerank.params.push_back({"damping_factor", PropertyType::kDouble, false});
8689
pagerank.params.push_back({"max_iterations", PropertyType::kInt32, false});
8790
pagerank.params.push_back({"epsilon", PropertyType::kDouble, false});
91+
pagerank.params.push_back({"result_limit", PropertyType::kInt32, false});
8892
pagerank.returns.push_back({"label_name", PropertyType::kString});
8993
pagerank.returns.push_back({"vertex_oid", PropertyType::kInt64});
9094
pagerank.returns.push_back({"pagerank", PropertyType::kDouble});

0 commit comments

Comments
 (0)