@@ -351,6 +351,36 @@ std::unique_ptr<katana::PropertyGraph>
351351katana::PropertyGraph::MakeProjectedGraph (
352352 const PropertyGraph& pg, const std::vector<std::string>& node_types,
353353 const std::vector<std::string>& edge_types) {
354+ auto ret = MakeProjectedGraph (
355+ pg, node_types.empty () ? std::nullopt : std::make_optional (node_types),
356+ edge_types.empty () ? std::nullopt : std::make_optional (edge_types));
357+ KATANA_LOG_VASSERT (ret.has_value (), " {}" , ret.error ());
358+ return std::move (ret.value ());
359+ }
360+
361+ katana::Result<std::unique_ptr<katana::PropertyGraph>>
362+ katana::PropertyGraph::MakeProjectedGraph (
363+ const PropertyGraph& pg, std::optional<std::vector<std::string>> node_types,
364+ std::optional<std::vector<std::string>> edge_types) {
365+ std::optional<SetOfEntityTypeIDs> node_type_ids;
366+ if (node_types) {
367+ node_type_ids = KATANA_CHECKED (
368+ pg.GetNodeTypeManager ().GetEntityTypeIDs (node_types.value ()));
369+ }
370+ std::optional<SetOfEntityTypeIDs> edge_type_ids;
371+ if (edge_types) {
372+ edge_type_ids = KATANA_CHECKED (
373+ pg.GetEdgeTypeManager ().GetEntityTypeIDs (edge_types.value ()));
374+ }
375+ return MakeProjectedGraph (pg, node_type_ids, edge_type_ids);
376+ }
377+
378+ // / Make a projected graph from a property graph. Shares state with
379+ // / the original graph.
380+ katana::Result<std::unique_ptr<katana::PropertyGraph>>
381+ katana::PropertyGraph::MakeProjectedGraph (
382+ const PropertyGraph& pg, std::optional<SetOfEntityTypeIDs> node_types,
383+ std::optional<SetOfEntityTypeIDs> edge_types) {
354384 const auto & topology = pg.topology ();
355385 if (topology.empty ()) {
356386 return std::make_unique<PropertyGraph>();
@@ -366,7 +396,7 @@ katana::PropertyGraph::MakeProjectedGraph(
366396 NUMAArray<Node> original_to_projected_nodes_mapping;
367397 original_to_projected_nodes_mapping.allocateInterleaved (topology.NumNodes ());
368398
369- if (node_types. empty () ) {
399+ if (! node_types) {
370400 num_new_nodes = topology.NumNodes ();
371401 // set all nodes
372402 katana::do_all (katana::iterate (topology.Nodes ()), [&](auto src) {
@@ -378,21 +408,14 @@ katana::PropertyGraph::MakeProjectedGraph(
378408 original_to_projected_nodes_mapping.begin (),
379409 original_to_projected_nodes_mapping.end (), Node{0 });
380410
381- std::set<katana::EntityTypeID> node_entity_type_ids;
382-
383- for (auto node_type : node_types) {
384- auto entity_type_id = pg.GetNodeEntityTypeID (node_type);
385- node_entity_type_ids.insert (entity_type_id);
386- }
387-
388411 katana::GAccumulator<uint32_t > accum_num_new_nodes;
389412
390413 katana::do_all (katana::iterate (topology.Nodes ()), [&](auto src) {
391- for (auto type : node_entity_type_ids ) {
414+ for (auto type : node_types. value () ) {
392415 if (pg.DoesNodeHaveType (src, type)) {
393416 accum_num_new_nodes += 1 ;
394417 bitset_nodes.set (src);
395- // this sets the correspondign entry in the array to 1
418+ // this sets the corresponding entry in the array to 1
396419 // will perform a prefix sum on this array later on
397420 original_to_projected_nodes_mapping[src] = 1 ;
398421 return ;
@@ -444,7 +467,7 @@ katana::PropertyGraph::MakeProjectedGraph(
444467 // initializes the edge-index array to all zeros
445468 katana::ParallelSTL::fill (out_indices.begin (), out_indices.end (), Edge{0 });
446469
447- if (edge_types. empty () ) {
470+ if (! edge_types) {
448471 katana::GAccumulator<uint32_t > accum_num_new_edges;
449472 // set all edges incident to projected nodes
450473 katana::do_all (
@@ -464,13 +487,6 @@ katana::PropertyGraph::MakeProjectedGraph(
464487
465488 num_new_edges = accum_num_new_edges.reduce ();
466489 } else {
467- std::set<katana::EntityTypeID> edge_entity_type_ids;
468-
469- for (auto edge_type : edge_types) {
470- auto entity_type_id = pg.GetEdgeEntityTypeID (edge_type);
471- edge_entity_type_ids.insert (entity_type_id);
472- }
473-
474490 katana::GAccumulator<uint32_t > accum_num_new_edges;
475491
476492 katana::do_all (
@@ -481,7 +497,7 @@ katana::PropertyGraph::MakeProjectedGraph(
481497 for (Edge e : topology.OutEdges (old_src)) {
482498 auto dest = topology.OutEdgeDst (e);
483499 if (bitset_nodes.test (dest)) {
484- for (auto type : edge_entity_type_ids ) {
500+ for (auto type : edge_types. value () ) {
485501 if (pg.DoesEdgeHaveTypeFromTopoIndex (e, type)) {
486502 accum_num_new_edges += 1 ;
487503 bitset_edges.set (e);
0 commit comments