Skip to content

Commit 3908bcf

Browse files
Run TopologicalSort in O(V+E) instead of O(V^2) (#144)
1 parent 8757b27 commit 3908bcf

File tree

2 files changed

+216
-58
lines changed

2 files changed

+216
-58
lines changed

dag.go

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ func TopologicalSort[K comparable, T any](g Graph[K, T]) ([]K, error) {
2727
return nil, fmt.Errorf("failed to get graph order: %w", err)
2828
}
2929

30+
adjacencyMap, err := g.AdjacencyMap()
31+
if err != nil {
32+
return nil, fmt.Errorf("failed to get adjacency map: %w", err)
33+
}
34+
3035
predecessorMap, err := g.PredecessorMap()
3136
if err != nil {
3237
return nil, fmt.Errorf("failed to get predecessor map: %w", err)
@@ -37,28 +42,28 @@ func TopologicalSort[K comparable, T any](g Graph[K, T]) ([]K, error) {
3742
for vertex, predecessors := range predecessorMap {
3843
if len(predecessors) == 0 {
3944
queue = append(queue, vertex)
45+
delete(predecessorMap, vertex)
4046
}
4147
}
4248

4349
order := make([]K, 0, gOrder)
44-
visited := make(map[K]struct{}, gOrder)
4550

4651
for len(queue) > 0 {
4752
currentVertex := queue[0]
4853
queue = queue[1:]
4954

50-
if _, ok := visited[currentVertex]; ok {
51-
continue
52-
}
53-
5455
order = append(order, currentVertex)
55-
visited[currentVertex] = struct{}{}
5656

57-
for vertex, predecessors := range predecessorMap {
57+
edgeMap := adjacencyMap[currentVertex]
58+
59+
for target := range edgeMap {
60+
61+
predecessors := predecessorMap[target]
5862
delete(predecessors, currentVertex)
5963

6064
if len(predecessors) == 0 {
61-
queue = append(queue, vertex)
65+
queue = append(queue, target)
66+
delete(predecessorMap, target)
6267
}
6368
}
6469
}
@@ -78,23 +83,31 @@ func StableTopologicalSort[K comparable, T any](g Graph[K, T], less func(K, K) b
7883
return nil, fmt.Errorf("topological sort cannot be computed on undirected graph")
7984
}
8085

86+
gOrder, err := g.Order()
87+
if err != nil {
88+
return nil, fmt.Errorf("failed to get graph order: %w", err)
89+
}
90+
91+
adjacencyMap, err := g.AdjacencyMap()
92+
if err != nil {
93+
return nil, fmt.Errorf("failed to get adjacency map: %w", err)
94+
}
95+
8196
predecessorMap, err := g.PredecessorMap()
8297
if err != nil {
8398
return nil, fmt.Errorf("failed to get predecessor map: %w", err)
8499
}
85100

86101
queue := make([]K, 0)
87-
queued := make(map[K]struct{})
88102

89103
for vertex, predecessors := range predecessorMap {
90104
if len(predecessors) == 0 {
91105
queue = append(queue, vertex)
92-
queued[vertex] = struct{}{}
106+
delete(predecessorMap, vertex)
93107
}
94108
}
95109

96-
order := make([]K, 0, len(predecessorMap))
97-
visited := make(map[K]struct{})
110+
order := make([]K, 0, gOrder)
98111

99112
sort.Slice(queue, func(i, j int) bool {
100113
return less(queue[i], queue[j])
@@ -104,28 +117,21 @@ func StableTopologicalSort[K comparable, T any](g Graph[K, T], less func(K, K) b
104117
currentVertex := queue[0]
105118
queue = queue[1:]
106119

107-
if _, ok := visited[currentVertex]; ok {
108-
continue
109-
}
110-
111120
order = append(order, currentVertex)
112-
visited[currentVertex] = struct{}{}
113121

114122
frontier := make([]K, 0)
115123

116-
for vertex, predecessors := range predecessorMap {
117-
delete(predecessors, currentVertex)
124+
edgeMap := adjacencyMap[currentVertex]
118125

119-
if len(predecessors) != 0 {
120-
continue
121-
}
126+
for target := range edgeMap {
122127

123-
if _, ok := queued[vertex]; ok {
124-
continue
125-
}
128+
predecessors := predecessorMap[target]
129+
delete(predecessors, currentVertex)
126130

127-
frontier = append(frontier, vertex)
128-
queued[vertex] = struct{}{}
131+
if len(predecessors) == 0 {
132+
frontier = append(frontier, target)
133+
delete(predecessorMap, target)
134+
}
129135
}
130136

131137
sort.Slice(frontier, func(i, j int) bool {
@@ -135,11 +141,6 @@ func StableTopologicalSort[K comparable, T any](g Graph[K, T], less func(K, K) b
135141
queue = append(queue, frontier...)
136142
}
137143

138-
gOrder, err := g.Order()
139-
if err != nil {
140-
return nil, fmt.Errorf("failed to get graph order: %w", err)
141-
}
142-
143144
if len(order) != gOrder {
144145
return nil, errors.New("topological sort cannot be computed on graph with cycles")
145146
}

dag_test.go

Lines changed: 183 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package graph
22

33
import (
44
"fmt"
5+
"math/rand"
56
"testing"
7+
"time"
68
)
79

810
func TestDirectedTopologicalSort(t *testing.T) {
@@ -25,6 +27,17 @@ func TestDirectedTopologicalSort(t *testing.T) {
2527
},
2628
expectedOrder: []int{1, 2, 3, 4, 5},
2729
},
30+
"graph with many possible topological orders": {
31+
vertices: []int{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60},
32+
edges: []Edge[int]{
33+
{Source: 1, Target: 10},
34+
{Source: 2, Target: 20},
35+
{Source: 3, Target: 30},
36+
{Source: 4, Target: 40},
37+
{Source: 5, Target: 50},
38+
{Source: 6, Target: 60},
39+
},
40+
},
2841
"graph with cycle": {
2942
vertices: []int{1, 2, 3},
3043
edges: []Edge[int]{
@@ -39,14 +52,9 @@ func TestDirectedTopologicalSort(t *testing.T) {
3952
for name, test := range tests {
4053
graph := New(IntHash, Directed())
4154

42-
for _, vertex := range test.vertices {
43-
_ = graph.AddVertex(vertex)
44-
}
45-
46-
for _, edge := range test.edges {
47-
if err := graph.AddEdge(edge.Source, edge.Target, EdgeWeight(edge.Properties.Weight)); err != nil {
48-
t.Fatalf("%s: failed to add edge: %s", name, err.Error())
49-
}
55+
err := buildGraph(&graph, test.vertices, test.edges)
56+
if err != nil {
57+
t.Fatalf("%s: failed to construct graph: %s", name, err.Error())
5058
}
5159

5260
order, err := TopologicalSort(graph)
@@ -59,8 +67,17 @@ func TestDirectedTopologicalSort(t *testing.T) {
5967
continue
6068
}
6169

62-
if len(order) != len(test.expectedOrder) {
63-
t.Errorf("%s: order length expectancy doesn't match: expected %v, got %v", name, len(test.expectedOrder), len(order))
70+
if len(order) != len(test.vertices) {
71+
t.Errorf("%s: order length expectancy doesn't match: expected %v, got %v", name, len(test.vertices), len(order))
72+
}
73+
74+
if len(test.expectedOrder) <= 0 {
75+
76+
fmt.Println("topological sort", order)
77+
78+
if err := verifyTopologicalSort(graph, order); err != nil {
79+
t.Errorf("%s: invalid topological sort - %v", name, err)
80+
}
6481
}
6582

6683
for i, expectedVertex := range test.expectedOrder {
@@ -143,14 +160,9 @@ func TestDirectedStableTopologicalSort(t *testing.T) {
143160
for name, test := range tests {
144161
graph := New(IntHash, Directed())
145162

146-
for _, vertex := range test.vertices {
147-
_ = graph.AddVertex(vertex)
148-
}
149-
150-
for _, edge := range test.edges {
151-
if err := graph.AddEdge(edge.Source, edge.Target, EdgeWeight(edge.Properties.Weight)); err != nil {
152-
t.Fatalf("%s: failed to add edge: %s", name, err.Error())
153-
}
163+
err := buildGraph(&graph, test.vertices, test.edges)
164+
if err != nil {
165+
t.Fatalf("%s: failed to construct graph: %s", name, err.Error())
154166
}
155167

156168
order, err := StableTopologicalSort(graph, func(a, b int) bool {
@@ -246,14 +258,9 @@ func TestDirectedTransitiveReduction(t *testing.T) {
246258
for name, test := range tests {
247259
graph := New(StringHash, Directed())
248260

249-
for _, vertex := range test.vertices {
250-
_ = graph.AddVertex(vertex)
251-
}
252-
253-
for _, edge := range test.edges {
254-
if err := graph.AddEdge(edge.Source, edge.Target, EdgeWeight(edge.Properties.Weight)); err != nil {
255-
t.Fatalf("%s: failed to add edge: %s", name, err.Error())
256-
}
261+
err := buildGraph(&graph, test.vertices, test.edges)
262+
if err != nil {
263+
t.Fatalf("%s: failed to construct graph: %s", name, err.Error())
257264
}
258265

259266
reduction, err := TransitiveReduction(graph)
@@ -303,6 +310,103 @@ func TestUndirectedTransitiveReduction(t *testing.T) {
303310
}
304311
}
305312

313+
func TestVerifyTopologicalSort(t *testing.T) {
314+
tests := map[string]struct {
315+
vertices []int
316+
edges []Edge[int]
317+
invalidOrder []int
318+
}{
319+
"graph with 2 vertices": {
320+
vertices: []int{1, 2},
321+
edges: []Edge[int]{
322+
{Source: 1, Target: 2},
323+
},
324+
},
325+
"graph with 2 vertices - reversed": {
326+
vertices: []int{1, 2},
327+
edges: []Edge[int]{
328+
{Source: 2, Target: 1},
329+
},
330+
},
331+
"graph with 2 vertices - invalid": {
332+
vertices: []int{1, 2},
333+
edges: []Edge[int]{
334+
{Source: 1, Target: 2},
335+
},
336+
invalidOrder: []int{2, 1},
337+
},
338+
"graph with 3 vertices": {
339+
vertices: []int{1, 2, 3},
340+
edges: []Edge[int]{
341+
{Source: 1, Target: 2},
342+
{Source: 1, Target: 3},
343+
{Source: 2, Target: 3},
344+
},
345+
},
346+
"graph with 3 vertices - invalid": {
347+
vertices: []int{1, 2, 3},
348+
edges: []Edge[int]{
349+
{Source: 1, Target: 2},
350+
{Source: 1, Target: 3},
351+
{Source: 2, Target: 3},
352+
},
353+
invalidOrder: []int{1, 3, 2},
354+
},
355+
"graph with 5 vertices": {
356+
vertices: []int{1, 2, 3, 4, 5},
357+
edges: []Edge[int]{
358+
{Source: 1, Target: 2},
359+
{Source: 1, Target: 3},
360+
{Source: 2, Target: 3},
361+
{Source: 2, Target: 4},
362+
{Source: 2, Target: 5},
363+
{Source: 3, Target: 4},
364+
{Source: 4, Target: 5},
365+
},
366+
},
367+
"graph with many possible topological orders": {
368+
vertices: []int{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60},
369+
edges: []Edge[int]{
370+
{Source: 1, Target: 10},
371+
{Source: 2, Target: 20},
372+
{Source: 3, Target: 30},
373+
{Source: 4, Target: 40},
374+
{Source: 5, Target: 50},
375+
{Source: 6, Target: 60},
376+
},
377+
invalidOrder: []int{2, 3, 4, 5, 6, 10, 1, 20, 30, 40, 50, 60},
378+
},
379+
}
380+
381+
for name, test := range tests {
382+
graph := New[int, int](IntHash, Directed())
383+
384+
err := buildGraph(&graph, test.vertices, test.edges)
385+
if err != nil {
386+
t.Fatalf("%s: failed to construct graph: %s", name, err.Error())
387+
}
388+
389+
var order[] int
390+
391+
if len(test.invalidOrder) > 0 {
392+
order = test.invalidOrder
393+
} else {
394+
order, err = TopologicalSort(graph)
395+
if err != nil {
396+
t.Fatalf("%s: error failed to produce topological sort: %v)", name, err)
397+
}
398+
}
399+
400+
err = verifyTopologicalSort(graph, order)
401+
402+
shouldFail := len(test.invalidOrder) > 0
403+
404+
if shouldFail != (err != nil) {
405+
t.Errorf("%s: error expectancy doesn't match: expected %v, got %v (error: %v)", name, shouldFail, err != nil, err)
406+
}
407+
}
408+
}
409+
306410
func slicesAreEqualWithFunc[T any](a, b []T, equals func(a, b T) bool) bool {
307411
if len(a) != len(b) {
308412
return false
@@ -322,3 +426,56 @@ func slicesAreEqualWithFunc[T any](a, b []T, equals func(a, b T) bool) bool {
322426

323427
return true
324428
}
429+
430+
// Please note that this call is destructive. Make a clone of your graph before calling if you
431+
// wish to preserve the graph.
432+
func verifyTopologicalSort[K comparable, T any](graph Graph[K, T], order []K) error {
433+
434+
adjacencyMap, err := graph.AdjacencyMap()
435+
if err != nil {
436+
return fmt.Errorf("failed to get adjacency map: %v", err)
437+
}
438+
439+
for i := range order {
440+
441+
for _, edge := range adjacencyMap[order[i]] {
442+
err = graph.RemoveEdge(edge.Source, edge.Target)
443+
if err != nil {
444+
return fmt.Errorf("failed to remove edge: %v -> %v : %v", edge.Source, edge.Target, err)
445+
}
446+
}
447+
448+
err = graph.RemoveVertex(order[i])
449+
if err != nil {
450+
return fmt.Errorf("failed to remove vertex: %v at index %d: %v", order[i], i, err)
451+
}
452+
}
453+
454+
return nil
455+
}
456+
457+
// randomizes the ordering of the edges and vertices to help ferret out any potential bugs
458+
// related to ordering
459+
func buildGraph[K comparable, T any](g *Graph[K, T], vertices []T, edges []Edge[K]) error {
460+
461+
if g == nil {
462+
return fmt.Errorf("graph must be initialized")
463+
}
464+
465+
rand.Seed(time.Now().UnixNano())
466+
rand.Shuffle(len(vertices), func(i, j int) { vertices[i], vertices[j] = vertices[j], vertices[i] })
467+
468+
for _, vertex := range vertices {
469+
_ = (*g).AddVertex(vertex)
470+
}
471+
472+
rand.Shuffle(len(edges), func(i, j int) { edges[i], edges[j] = edges[j], edges[i] })
473+
474+
for _, edge := range edges {
475+
if err := (*g).AddEdge(edge.Source, edge.Target, EdgeWeight(edge.Properties.Weight)); err != nil {
476+
return err
477+
}
478+
}
479+
480+
return nil
481+
}

0 commit comments

Comments
 (0)