Skip to content

Commit 153e1d5

Browse files
Chilleesimonlindholm
authored andcommitted
Added directedMST/Minimum arborescence code (#110)
1 parent ed0c783 commit 153e1d5

File tree

4 files changed

+236
-3
lines changed

4 files changed

+236
-3
lines changed

content/data-structures/UnionFind.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ struct UF {
1414
bool same_set(int a, int b) { return find(a) == find(b); }
1515
int size(int x) { return -e[find(x)]; }
1616
int find(int x) { return e[x] < 0 ? x : e[x] = find(e[x]); }
17-
void join(int a, int b) {
17+
bool join(int a, int b) {
1818
a = find(a), b = find(b);
19-
if (a == b) return;
19+
if (a == b) return false;
2020
if (e[a] > e[b]) swap(a, b);
2121
e[a] += e[b]; e[b] = a;
22+
return true;
2223
}
2324
};

content/graph/DirectedMST.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/**
2+
* Author: chilli, Takanori MAEHARA
3+
* Date: 2019-05-10
4+
* License: CC0
5+
* Source: https://github.com/spaghetti-source/algorithm/blob/master/graph/arborescence.cc
6+
* Description: Edmonds' algorithm for finding the weight of the minimum spanning
7+
* tree/arborescence of a directed graph, given a root node. If no MST exists, returns -1.
8+
* Time: O(E \log V)
9+
* Status: Fuzz-tested, also tested on NWERC 2018 fastestspeedrun
10+
*/
11+
#pragma once
12+
13+
#include "../data-structures/UnionFind.h"
14+
15+
struct Edge { int a, b; ll w; };
16+
struct Node { /// lazy skew heap node
17+
Edge key;
18+
Node *l, *r;
19+
ll delta;
20+
void prop() {
21+
key.w += delta;
22+
if (l) l->delta += delta;
23+
if (r) r->delta += delta;
24+
delta = 0;
25+
}
26+
Edge top() { prop(); return key; }
27+
};
28+
Node *merge(Node *a, Node *b) {
29+
if (!a || !b) return a ?: b;
30+
a->prop(), b->prop();
31+
if (a->key.w > b->key.w) swap(a, b);
32+
swap(a->l, (a->r = merge(b, a->r)));
33+
return a;
34+
}
35+
void pop(Node*& a) { a->prop(); a = merge(a->l, a->r); }
36+
37+
ll dmst(int n, int r, vector<Edge>& g) {
38+
UF uf(n);
39+
vector<Node*> heap(n);
40+
trav(e, g) heap[e.b] = merge(heap[e.b], new Node{e});
41+
ll res = 0;
42+
vi seen(n, -1), path(n);
43+
seen[r] = r;
44+
rep(s,0,n) {
45+
int u = s, qi = 0, w;
46+
while (seen[u] < 0) {
47+
path[qi++] = u, seen[u] = s;
48+
if (!heap[u]) return -1;
49+
Edge e = heap[u]->top();
50+
heap[u]->delta -= e.w, pop(heap[u]);
51+
res += e.w, u = uf.find(e.a);
52+
if (seen[u] == s) {
53+
Node* cyc = 0;
54+
do cyc = merge(cyc, heap[w = path[--qi]]);
55+
while (uf.join(u, w));
56+
u = uf.find(u);
57+
heap[u] = cyc, seen[u] = -1;
58+
}
59+
}
60+
}
61+
return res;
62+
}

content/graph/chapter.tex

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@ \section{Trees}
3939
\kactlimport{CompressTree.h}
4040
\kactlimport{HLD.h}
4141
\kactlimport{LinkCutTree.h}
42+
\kactlimport{DirectedMST.h}
4243
\kactlimport{MatrixTree.h}
43-
\columnbreak
44+
\hardcolumnbreak

fuzz-tests/graph/DirectedMST.cpp

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#include <bits/stdc++.h>
2+
using namespace std;
3+
4+
#define rep(i, a, b) for(int i = a; i < int(b); ++i)
5+
#define trav(a, v) for(auto& a : v)
6+
#define all(x) x.begin(), x.end()
7+
#define sz(x) (int)(x).size()
8+
9+
typedef long long ll;
10+
typedef pair<int, int> pii;
11+
typedef vector<int> vi;
12+
13+
struct Bumpalloc {
14+
char buf[450 << 20];
15+
size_t bufp;
16+
void* alloc(size_t s) {
17+
assert(s < bufp);
18+
return (void*)&buf[bufp -= s];
19+
}
20+
Bumpalloc() { reset(); }
21+
22+
template<class T> T* operator=(T&& x) {
23+
T* r = (T*)alloc(sizeof(T));
24+
new(r) T(move(x));
25+
return r;
26+
}
27+
void reset() { bufp = sizeof buf; }
28+
} bumpalloc;
29+
30+
// When not testing perf, we don't want to leak memory
31+
#ifndef TEST_PERF
32+
#define new bumpalloc =
33+
#endif
34+
#include "../../content/graph/DirectedMST.h"
35+
#ifndef TEST_PERF
36+
#undef new
37+
#endif
38+
39+
namespace mit {
40+
41+
#define N 110000
42+
#define M 110000
43+
#define inf 2000000000
44+
45+
struct edg {
46+
int u, v;
47+
int cost;
48+
} E[M], E_copy[M];
49+
50+
int In[N], ID[N], vis[N], pre[N];
51+
52+
// edges pointed from root.
53+
int Directed_MST(int root, int NV, int NE) {
54+
for (int i = 0; i < NE; i++)
55+
E_copy[i] = E[i];
56+
int ret = 0;
57+
int u, v;
58+
while (true) {
59+
rep(i,0,NV) In[i] = inf;
60+
rep(i,0,NE) {
61+
u = E_copy[i].u;
62+
v = E_copy[i].v;
63+
if(E_copy[i].cost < In[v] && u != v) {
64+
In[v] = E_copy[i].cost;
65+
pre[v] = u;
66+
}
67+
}
68+
rep(i,0,NV) {
69+
if(i == root) continue;
70+
if(In[i] == inf) return -1; // no solution
71+
}
72+
73+
int cnt = 0;
74+
rep(i,0,NV) {
75+
ID[i] = -1;
76+
vis[i] = -1;
77+
}
78+
In[root] = 0;
79+
80+
rep(i,0,NV) {
81+
ret += In[i];
82+
int v = i;
83+
while(vis[v] != i && ID[v] == -1 && v != root) {
84+
vis[v] = i;
85+
v = pre[v];
86+
}
87+
if(v != root && ID[v] == -1) {
88+
for(u = pre[v]; u != v; u = pre[u]) {
89+
ID[u] = cnt;
90+
}
91+
ID[v] = cnt++;
92+
}
93+
}
94+
if(cnt == 0) break;
95+
rep(i,0,NV) {
96+
if(ID[i] == -1) ID[i] = cnt++;
97+
}
98+
rep(i,0,NE) {
99+
v = E_copy[i].v;
100+
E_copy[i].u = ID[E_copy[i].u];
101+
E_copy[i].v = ID[E_copy[i].v];
102+
if(E_copy[i].u != E_copy[i].v) {
103+
E_copy[i].cost -= In[v];
104+
}
105+
}
106+
NV = cnt;
107+
root = ID[root];
108+
}
109+
return ret;
110+
}
111+
}
112+
113+
int adj[105][105];
114+
int main() {
115+
rep(it,0,100000) {
116+
bumpalloc.reset();
117+
int n = (rand()%20)+1;
118+
int density = rand() % 101;
119+
int r = rand()%n;
120+
int cnt = 0;
121+
vector<Edge> edges;
122+
rep(i,0,n)
123+
rep(j,0,n){
124+
if (i==j) continue;
125+
if (rand() % 100 >= density) continue;
126+
int weight = rand()%100;
127+
mit::E[cnt++] = {i,j, weight};
128+
edges.push_back({i,j,weight});
129+
adj[i][j] = weight;
130+
}
131+
ll ans1 = mit::Directed_MST(r, n, cnt);
132+
ll ans2 = dmst(n, r, edges);
133+
assert(ans1 == ans2);
134+
// For verifying a reconstruction:
135+
/*
136+
if (ans1 != -1) {
137+
vi par = pa.second;
138+
if (debug) {
139+
cout << "r = " << r << endl;
140+
trav(x, par) cout << x << ' ';
141+
cout << endl;
142+
trav(e, edges) {
143+
cout << e.a << ' ' << e.b << ' ' << e.w << endl;
144+
}
145+
}
146+
ll sum = 0;
147+
vector<vi> ch(n);
148+
rep(i,0,n) {
149+
if (i == r) assert(par[i] == -1);
150+
else {
151+
assert(par[i] != -1);
152+
sum += adj[par[i]][i];
153+
ch[par[i]].push_back(i);
154+
}
155+
}
156+
assert(sum == ans1);
157+
vi seen(n), q = {r};
158+
rep(qi,0,sz(q)) {
159+
int s = q[qi];
160+
if (!seen[s]++)
161+
trav(x, ch[s]) q.push_back(x);
162+
}
163+
assert(count(all(seen), 0) == 0);
164+
}
165+
*/
166+
}
167+
cout<<"Tests passed!"<<endl;
168+
return 0;
169+
}

0 commit comments

Comments
 (0)