Skip to content

Commit d73c54a

Browse files
committed
Generic dfs template
1 parent 2c41eb9 commit d73c54a

File tree

2 files changed

+151
-84
lines changed

2 files changed

+151
-84
lines changed

cp-algo/graph/dfs.hpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#ifndef CP_ALGO_GRAPH_DFS_HPP
2+
#define CP_ALGO_GRAPH_DFS_HPP
3+
#include "base.hpp"
4+
#include <variant>
5+
#include <stack>
6+
namespace cp_algo::graph {
7+
enum node_state { unvisited, visiting, visited, blocked };
8+
9+
template<graph_type graph>
10+
struct dfs_context {
11+
big_vector<node_state> state;
12+
graph const* g;
13+
14+
dfs_context(graph const& g): state(g.n()), g(&g) {}
15+
16+
// Called when first entering a node
17+
void on_enter(node_index) {}
18+
19+
// Called when discovering a tree edge (v->u is a tree edge, u is unvisited)
20+
void on_tree_edge(node_index, node_index, edge_index) {}
21+
22+
// Called after returning from a child via tree edge
23+
void on_return_from_child(node_index, node_index, edge_index) {}
24+
25+
// Called when encountering a back edge (v->u, u is visiting)
26+
void on_back_edge(node_index, node_index, edge_index) {}
27+
28+
// Called when encountering a forward/cross edge (v->u, u is visited)
29+
void on_forward_cross_edge(node_index, node_index, edge_index) {}
30+
31+
// Called when exiting a node (all edges processed)
32+
void on_exit(node_index) {}
33+
};
34+
35+
template<template<typename> class Context, graph_type graph>
36+
Context<graph>& dfs(Context<graph>& context) {
37+
graph const& g = *context.g;
38+
auto const& adj = g.incidence_lists();
39+
struct frame {
40+
node_index v;
41+
[[no_unique_address]] std::conditional_t<
42+
undirected_graph_type<graph>,
43+
edge_index, std::monostate> ep;
44+
int sv; // edge index in stack_union
45+
enum { INIT, PROCESS_EDGES, HANDLE_CHILD } state;
46+
};
47+
48+
std::stack<frame> dfs_stack;
49+
50+
for (auto root: g.nodes()) {
51+
if (context.state[root] != unvisited) continue;
52+
53+
if constexpr (undirected_graph_type<graph>) {
54+
dfs_stack.push({root, -1, 0, frame::INIT});
55+
} else {
56+
dfs_stack.push({root, {}, 0, frame::INIT});
57+
}
58+
59+
while (!dfs_stack.empty()) {
60+
auto& f = dfs_stack.top();
61+
62+
if (f.state == frame::INIT) {
63+
context.state[f.v] = visiting;
64+
context.on_enter(f.v);
65+
f.sv = adj.head[f.v];
66+
f.state = frame::PROCESS_EDGES;
67+
}
68+
69+
if (f.state == frame::HANDLE_CHILD) {
70+
auto e = adj.data[f.sv];
71+
f.sv = adj.next[f.sv];
72+
node_index u = g.edge(e).to;
73+
context.on_return_from_child(f.v, u, e);
74+
f.state = frame::PROCESS_EDGES;
75+
}
76+
77+
// PROCESS_EDGES
78+
bool found_child = false;
79+
while (f.sv != 0) {
80+
auto e = adj.data[f.sv];
81+
82+
if constexpr (undirected_graph_type<graph>) {
83+
if (f.ep == graph::opposite_idx(e)) {
84+
f.sv = adj.next[f.sv];
85+
continue;
86+
}
87+
}
88+
89+
node_index u = g.edge(e).to;
90+
if (context.state[u] == unvisited) {
91+
context.on_tree_edge(f.v, u, e);
92+
f.state = frame::HANDLE_CHILD;
93+
if constexpr (undirected_graph_type<graph>) {
94+
dfs_stack.push({u, e, 0, frame::INIT});
95+
} else {
96+
dfs_stack.push({u, {}, 0, frame::INIT});
97+
}
98+
found_child = true;
99+
break;
100+
} else if (context.state[u] == visiting) {
101+
context.on_back_edge(f.v, u, e);
102+
} else if (context.state[u] == visited) {
103+
context.on_forward_cross_edge(f.v, u, e);
104+
}
105+
f.sv = adj.next[f.sv];
106+
}
107+
108+
if (found_child) continue;
109+
110+
// All edges processed
111+
context.state[f.v] = visited;
112+
context.on_exit(f.v);
113+
dfs_stack.pop();
114+
}
115+
}
116+
return context;
117+
}
118+
}
119+
#endif // CP_ALGO_GRAPH_DFS_HPP

cp-algo/graph/tarjan.hpp

Lines changed: 32 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,67 @@
11
#ifndef CP_ALGO_GRAPH_TARJAN_HPP
22
#define CP_ALGO_GRAPH_TARJAN_HPP
3+
#include "dfs.hpp"
34
#include "base.hpp"
45
#include "../structures/csr.hpp"
56
#include <algorithm>
67
#include <cassert>
78
#include <stack>
89
namespace cp_algo::graph {
9-
enum node_state { unvisited, visiting, visited, blocked };
1010
template<graph_type graph>
11-
struct tarjan_context {
11+
struct tarjan_context: dfs_context<graph> {
1212
big_vector<int> tin, low;
13-
big_vector<node_state> state;
1413
std::stack<int> stack;
15-
graph const* g;
1614
int timer;
1715
structures::csr<node_index> components;
18-
tarjan_context(graph const& g):
19-
tin(g.n()), low(g.n()), state(g.n()), g(&g), timer(0) {
16+
17+
tarjan_context(graph const& g): dfs_context<graph>(g),
18+
tin(g.n()), low(g.n()), timer(0) {
2019
components.reserve_data(g.n());
2120
}
2221

23-
void on_tree_edge(node_index, edge_index) {}
24-
void on_exit(node_index) {}
22+
void on_enter(node_index v) {
23+
tin[v] = low[v] = timer++;
24+
stack.push(v);
25+
}
26+
27+
void on_return_from_child(node_index v, node_index u, edge_index) {
28+
low[v] = std::min(low[v], low[u]);
29+
}
30+
31+
void on_back_edge(node_index v, node_index u, edge_index) {
32+
low[v] = std::min(low[v], tin[u]);
33+
}
34+
35+
void on_forward_cross_edge(node_index v, node_index u, edge_index) {
36+
low[v] = std::min(low[v], tin[u]);
37+
}
38+
39+
void on_tree_edge_processed(node_index, node_index, edge_index) {}
2540

2641
void collect(node_index v) {
2742
components.new_row();
2843
node_index u;
2944
do {
3045
u = stack.top();
3146
stack.pop();
32-
state[u] = blocked;
47+
this->state[u] = blocked;
3348
components.push(u);
3449
} while(u != v);
3550
}
3651
};
52+
3753
template<template<typename> class Context, graph_type graph>
3854
auto tarjan(graph const& g) {
3955
Context<graph> context(g);
40-
41-
auto const& adj = g.incidence_lists();
42-
43-
struct frame {
44-
node_index v;
45-
edge_index ep;
46-
int sv; // edge index in stack_union
47-
enum { INIT, PROCESS_EDGES, HANDLE_CHILD } state;
48-
};
49-
50-
std::stack<frame> dfs_stack;
51-
52-
for (auto root: g.nodes()) {
53-
if (context.state[root] != unvisited) continue;
54-
55-
dfs_stack.push({root, -1, 0, frame::INIT});
56-
57-
while (!dfs_stack.empty()) {
58-
auto& f = dfs_stack.top();
59-
60-
if (f.state == frame::INIT) {
61-
context.state[f.v] = visiting;
62-
context.tin[f.v] = context.low[f.v] = context.timer++;
63-
context.stack.push(f.v);
64-
f.sv = adj.head[f.v];
65-
f.state = frame::PROCESS_EDGES;
66-
}
67-
68-
if (f.state == frame::HANDLE_CHILD) {
69-
auto e = adj.data[f.sv];
70-
f.sv = adj.next[f.sv];
71-
node_index u = g.edge(e).to;
72-
context.low[f.v] = std::min(context.low[f.v], context.low[u]);
73-
context.on_tree_edge(f.v, u);
74-
f.state = frame::PROCESS_EDGES;
75-
}
76-
77-
// PROCESS_EDGES
78-
bool found_child = false;
79-
while (f.sv != 0) {
80-
auto e = adj.data[f.sv];
81-
82-
if constexpr (undirected_graph_type<graph>) {
83-
if (f.ep == graph::opposite_idx(e)) {
84-
f.sv = adj.next[f.sv];
85-
continue;
86-
}
87-
}
88-
89-
node_index u = g.edge(e).to;
90-
if (context.state[u] == unvisited) {
91-
f.state = frame::HANDLE_CHILD;
92-
dfs_stack.push({u, e, 0, frame::INIT});
93-
found_child = true;
94-
break;
95-
} else if (context.state[u] != blocked) {
96-
context.low[f.v] = std::min(context.low[f.v], context.tin[u]);
97-
}
98-
f.sv = adj.next[f.sv];
99-
}
100-
101-
if (found_child) continue;
102-
103-
// All edges processed
104-
context.state[f.v] = visited;
105-
context.on_exit(f.v);
106-
dfs_stack.pop();
107-
}
108-
}
109-
return context.components;
56+
return dfs(context).components;
11057
}
11158
template<graph_type graph>
11259
struct exit_context: tarjan_context<graph> {
113-
using base = tarjan_context<graph>;
114-
using base::base;
60+
using tarjan_context<graph>::tarjan_context;
11561

11662
void on_exit(node_index v) {
117-
if (base::low[v] == base::tin[v]) {
118-
base::collect(v);
63+
if (this->low[v] == this->tin[v]) {
64+
this->collect(v);
11965
}
12066
}
12167
};
@@ -136,7 +82,9 @@ namespace cp_algo::graph {
13682
struct bcc_context: tarjan_context<graph> {
13783
using base = tarjan_context<graph>;
13884
using base::base;
139-
void on_tree_edge(node_index v, node_index u) {
85+
86+
void on_return_from_child(node_index v, node_index u, edge_index e) {
87+
base::on_return_from_child(v, u, e);
14088
if (base::low[u] >= base::tin[v]) {
14189
base::collect(u);
14290
base::components.push(v);

0 commit comments

Comments
 (0)