diff --git a/directed.go b/directed.go index 95d67f5..92834c4 100644 --- a/directed.go +++ b/directed.go @@ -1,7 +1,6 @@ package graph import ( - "errors" "fmt" ) @@ -76,20 +75,6 @@ func (d *directed[K, T]) RemoveVertex(hash K) error { } func (d *directed[K, T]) AddEdge(sourceHash, targetHash K, options ...func(*EdgeProperties)) error { - _, _, err := d.store.Vertex(sourceHash) - if err != nil { - return fmt.Errorf("source vertex %v: %w", sourceHash, err) - } - - _, _, err = d.store.Vertex(targetHash) - if err != nil { - return fmt.Errorf("target vertex %v: %w", targetHash, err) - } - - if _, err := d.Edge(sourceHash, targetHash); !errors.Is(err, ErrEdgeNotFound) { - return ErrEdgeAlreadyExists - } - // If the user opted in to preventing cycles, run a cycle check. if d.traits.PreventCycles { createsCycle, err := d.createsCycle(sourceHash, targetHash) @@ -97,7 +82,7 @@ func (d *directed[K, T]) AddEdge(sourceHash, targetHash K, options ...func(*Edge return fmt.Errorf("check for cycles: %w", err) } if createsCycle { - return ErrEdgeCreatesCycle + return &EdgeCausesCycleError[K]{Source: sourceHash, Target: targetHash} } } @@ -176,10 +161,6 @@ func (d *directed[K, T]) UpdateEdge(source, target K, options ...func(properties } func (d *directed[K, T]) RemoveEdge(source, target K) error { - if _, err := d.Edge(source, target); err != nil { - return err - } - if err := d.store.RemoveEdge(source, target); err != nil { return fmt.Errorf("failed to remove edge from %v to %v: %w", source, target, err) } @@ -273,17 +254,11 @@ func (d *directed[K, T]) Order() (int, error) { } func (d *directed[K, T]) Size() (int, error) { - size := 0 - outEdges, err := d.AdjacencyMap() + edges, err := d.store.ListEdges() if err != nil { - return 0, fmt.Errorf("failed to get adjacency map: %w", err) + return 0, fmt.Errorf("failed to list edges: %w", err) } - - for _, outEdges := range outEdges { - size += len(outEdges) - } - - return size, nil + return len(edges), nil } func (d *directed[K, T]) edgesAreEqual(a, b Edge[T]) bool { diff --git a/directed_test.go b/directed_test.go index 34a9c8d..fb34d01 100644 --- a/directed_test.go +++ b/directed_test.go @@ -886,7 +886,7 @@ func TestDirected_RemoveEdge(t *testing.T) { removeEdges: []Edge[int]{ {Source: 2, Target: 3}, }, - expectedError: ErrEdgeNotFound, + // Expect no error because memoryStore doesn't error }, } @@ -909,7 +909,7 @@ func TestDirected_RemoveEdge(t *testing.T) { } // After removing the edge, verify that it can't be retrieved using // Edge anymore. - if _, err := graph.Edge(removeEdge.Source, removeEdge.Target); err != ErrEdgeNotFound { + if _, err := graph.Edge(removeEdge.Source, removeEdge.Target); !errors.Is(err, ErrEdgeNotFound) { t.Fatalf("%s: error expectancy doesn't match: expected %v, got %v", name, ErrEdgeNotFound, err) } } @@ -1267,6 +1267,8 @@ func TestDirected_addEdge(t *testing.T) { graph := newDirected(IntHash, &Traits{}, newMemoryStore[int, int]()) for _, edge := range test.edges { + _ = graph.AddVertex(edge.Source) + _ = graph.AddVertex(edge.Target) sourceHash := graph.hash(edge.Source) TargetHash := graph.hash(edge.Target) err := graph.addEdge(sourceHash, TargetHash, edge) diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..a9f124d --- /dev/null +++ b/errors.go @@ -0,0 +1,74 @@ +package graph + +import ( + "errors" + "fmt" +) + +type ( + VertexAlreadyExistsError[K comparable, T any] struct { + Key K + ExistingValue T + } + + VertexNotFoundError[K comparable] struct { + Key K + } + + EdgeAlreadyExistsError[K comparable] struct { + Source, Target K + } + + EdgeNotFoundError[K comparable] struct { + Source, Target K + } + + VertexHasEdgesError[K comparable] struct { + Key K + Count int + } + + EdgeCausesCycleError[K comparable] struct { + Source, Target K + } +) + +func (e *VertexAlreadyExistsError[K, T]) Error() string { + return fmt.Sprintf("vertex %v already exists with value %v", e.Key, e.ExistingValue) +} + +func (e *VertexNotFoundError[K]) Error() string { + return fmt.Sprintf("vertex %v not found", e.Key) +} + +func (e *EdgeAlreadyExistsError[K]) Error() string { + return fmt.Sprintf("edge %v - %v already exists", e.Source, e.Target) +} + +func (e *EdgeNotFoundError[K]) Error() string { + return fmt.Sprintf("edge %v - %v not found", e.Source, e.Target) +} + +func (e *VertexHasEdgesError[K]) Error() string { + return fmt.Sprintf("vertex %v has %d edges", e.Key, e.Count) +} + +func (e *EdgeCausesCycleError[K]) Error() string { + return fmt.Sprintf("edge %v - %v would cause a cycle", e.Source, e.Target) +} + +var ( + ErrVertexNotFound = errors.New("vertex not found") + ErrVertexAlreadyExists = errors.New("vertex already exists") + ErrEdgeNotFound = errors.New("edge not found") + ErrEdgeAlreadyExists = errors.New("edge already exists") + ErrEdgeCreatesCycle = errors.New("edge would create a cycle") + ErrVertexHasEdges = errors.New("vertex has edges") +) + +func (e *VertexAlreadyExistsError[K, T]) Unwrap() error { return ErrVertexAlreadyExists } +func (e *VertexNotFoundError[K]) Unwrap() error { return ErrVertexNotFound } +func (e *EdgeAlreadyExistsError[K]) Unwrap() error { return ErrEdgeAlreadyExists } +func (e *EdgeNotFoundError[K]) Unwrap() error { return ErrEdgeNotFound } +func (e *VertexHasEdgesError[K]) Unwrap() error { return ErrVertexHasEdges } +func (e *EdgeCausesCycleError[K]) Unwrap() error { return ErrEdgeCreatesCycle } diff --git a/graph.go b/graph.go index 9376eb5..f593ec2 100644 --- a/graph.go +++ b/graph.go @@ -49,17 +49,6 @@ // For detailed usage examples, take a look at the README. package graph -import "errors" - -var ( - ErrVertexNotFound = errors.New("vertex not found") - ErrVertexAlreadyExists = errors.New("vertex already exists") - ErrEdgeNotFound = errors.New("edge not found") - ErrEdgeAlreadyExists = errors.New("edge already exists") - ErrEdgeCreatesCycle = errors.New("edge would create a cycle") - ErrVertexHasEdges = errors.New("vertex has edges") -) - // Graph represents a generic graph data structure consisting of vertices of // type T identified by a hash of type K. type Graph[K comparable, T any] interface { diff --git a/paths.go b/paths.go index 130547d..1f546dd 100644 --- a/paths.go +++ b/paths.go @@ -15,11 +15,11 @@ var ErrTargetNotReachable = errors.New("target vertex not reachable from source" // of the source vertex. In order to determine this, CreatesCycle runs a DFS. func CreatesCycle[K comparable, T any](g Graph[K, T], source, target K) (bool, error) { if _, err := g.Vertex(source); err != nil { - return false, fmt.Errorf("could not get vertex with hash %v: %w", source, err) + return false, fmt.Errorf("could not get source vertex: %w", err) } if _, err := g.Vertex(target); err != nil { - return false, fmt.Errorf("could not get vertex with hash %v: %w", target, err) + return false, fmt.Errorf("could not get target vertex: %w", err) } if source == target { diff --git a/store.go b/store.go index 30dc0eb..1bbe245 100644 --- a/store.go +++ b/store.go @@ -88,8 +88,11 @@ func (s *memoryStore[K, T]) AddVertex(k K, t T, p VertexProperties) error { s.lock.Lock() defer s.lock.Unlock() - if _, ok := s.vertices[k]; ok { - return ErrVertexAlreadyExists + if existing, ok := s.vertices[k]; ok { + return &VertexAlreadyExistsError[K, T]{ + Key: k, + ExistingValue: existing, + } } s.vertices[k] = t @@ -120,10 +123,15 @@ func (s *memoryStore[K, T]) VertexCount() (int, error) { func (s *memoryStore[K, T]) Vertex(k K) (T, VertexProperties, error) { s.lock.RLock() defer s.lock.RUnlock() + return s.vertexWithLock(k) +} +// vertexWithLock returns the vertex and vertex properties - the caller must be holding at least a +// read-level lock. +func (s *memoryStore[K, T]) vertexWithLock(k K) (T, VertexProperties, error) { v, ok := s.vertices[k] if !ok { - return v, VertexProperties{}, ErrVertexNotFound + return v, VertexProperties{}, &VertexNotFoundError[K]{Key: k} } p := s.vertexProperties[k] @@ -136,19 +144,19 @@ func (s *memoryStore[K, T]) RemoveVertex(k K) error { defer s.lock.RUnlock() if _, ok := s.vertices[k]; !ok { - return ErrVertexNotFound + return &VertexNotFoundError[K]{Key: k} } if edges, ok := s.inEdges[k]; ok { - if len(edges) > 0 { - return ErrVertexHasEdges + if count := len(edges); count > 0 { + return &VertexHasEdgesError[K]{Key: k, Count: count} } delete(s.inEdges, k) } if edges, ok := s.outEdges[k]; ok { - if len(edges) > 0 { - return ErrVertexHasEdges + if count := len(edges); count > 0 { + return &VertexHasEdgesError[K]{Key: k, Count: count} } delete(s.outEdges, k) } @@ -163,29 +171,45 @@ func (s *memoryStore[K, T]) AddEdge(sourceHash, targetHash K, edge Edge[K]) erro s.lock.Lock() defer s.lock.Unlock() + if _, _, err := s.vertexWithLock(sourceHash); err != nil { + return fmt.Errorf("could not get source vertex: %w", &VertexNotFoundError[K]{Key: sourceHash}) + } + if _, ok := s.outEdges[sourceHash]; !ok { s.outEdges[sourceHash] = make(map[K]Edge[K]) } + if _, ok := s.outEdges[sourceHash][targetHash]; ok { + return &EdgeAlreadyExistsError[K]{Source: sourceHash, Target: targetHash} + } + s.outEdges[sourceHash][targetHash] = edge + if _, _, err := s.vertexWithLock(targetHash); err != nil { + return fmt.Errorf("could not get target vertex: %w", &VertexNotFoundError[K]{Key: targetHash}) + } + if _, ok := s.inEdges[targetHash]; !ok { s.inEdges[targetHash] = make(map[K]Edge[K]) } + if _, ok := s.inEdges[targetHash][sourceHash]; ok { + return &EdgeAlreadyExistsError[K]{Source: sourceHash, Target: targetHash} + } + s.inEdges[targetHash][sourceHash] = edge return nil } func (s *memoryStore[K, T]) UpdateEdge(sourceHash, targetHash K, edge Edge[K]) error { - if _, err := s.Edge(sourceHash, targetHash); err != nil { - return err - } - s.lock.Lock() defer s.lock.Unlock() + if _, err := s.edgeWithLock(sourceHash, targetHash); err != nil { + return err + } + s.outEdges[sourceHash][targetHash] = edge s.inEdges[targetHash][sourceHash] = edge @@ -204,15 +228,19 @@ func (s *memoryStore[K, T]) RemoveEdge(sourceHash, targetHash K) error { func (s *memoryStore[K, T]) Edge(sourceHash, targetHash K) (Edge[K], error) { s.lock.RLock() defer s.lock.RUnlock() + return s.edgeWithLock(sourceHash, targetHash) +} +// edgeWithLock returns the edge - the caller must be holding at least a read-level lock. +func (s *memoryStore[K, T]) edgeWithLock(sourceHash, targetHash K) (Edge[K], error) { sourceEdges, ok := s.outEdges[sourceHash] if !ok { - return Edge[K]{}, ErrEdgeNotFound + return Edge[K]{}, &EdgeNotFoundError[K]{Source: sourceHash, Target: targetHash} } edge, ok := sourceEdges[targetHash] if !ok { - return Edge[K]{}, ErrEdgeNotFound + return Edge[K]{}, &EdgeNotFoundError[K]{Source: sourceHash, Target: targetHash} } return edge, nil @@ -237,12 +265,15 @@ func (s *memoryStore[K, T]) ListEdges() ([]Edge[K], error) { // Because CreatesCycle doesn't need to modify the PredecessorMap, we can use // inEdges instead to compute the same thing without creating any copies. func (s *memoryStore[K, T]) CreatesCycle(source, target K) (bool, error) { - if _, _, err := s.Vertex(source); err != nil { - return false, fmt.Errorf("could not get vertex with hash %v: %w", source, err) + s.lock.RLock() + defer s.lock.RUnlock() + + if _, _, err := s.vertexWithLock(source); err != nil { + return false, fmt.Errorf("could not get source vertex: %w", err) } - if _, _, err := s.Vertex(target); err != nil { - return false, fmt.Errorf("could not get vertex with hash %v: %w", target, err) + if _, _, err := s.vertexWithLock(target); err != nil { + return false, fmt.Errorf("could not get target vertex: %w", err) } if source == target { diff --git a/undirected.go b/undirected.go index 37d320c..a263ef5 100644 --- a/undirected.go +++ b/undirected.go @@ -57,27 +57,14 @@ func (u *undirected[K, T]) RemoveVertex(hash K) error { } func (u *undirected[K, T]) AddEdge(sourceHash, targetHash K, options ...func(*EdgeProperties)) error { - if _, _, err := u.store.Vertex(sourceHash); err != nil { - return fmt.Errorf("could not find source vertex with hash %v: %w", sourceHash, err) - } - - if _, _, err := u.store.Vertex(targetHash); err != nil { - return fmt.Errorf("could not find target vertex with hash %v: %w", targetHash, err) - } - - //nolint:govet // False positive. - if _, err := u.Edge(sourceHash, targetHash); !errors.Is(err, ErrEdgeNotFound) { - return ErrEdgeAlreadyExists - } - // If the user opted in to preventing cycles, run a cycle check. if u.traits.PreventCycles { - createsCycle, err := CreatesCycle[K, T](u, sourceHash, targetHash) + createsCycle, err := u.createsCycle(sourceHash, targetHash) if err != nil { return fmt.Errorf("check for cycles: %w", err) } if createsCycle { - return ErrEdgeCreatesCycle + return &EdgeCausesCycleError[K]{Source: sourceHash, Target: targetHash} } } @@ -93,11 +80,7 @@ func (u *undirected[K, T]) AddEdge(sourceHash, targetHash K, options ...func(*Ed option(&edge.Properties) } - if err := u.addEdge(sourceHash, targetHash, edge); err != nil { - return fmt.Errorf("failed to add edge: %w", err) - } - - return nil + return u.addEdge(sourceHash, targetHash, edge) } func (u *undirected[K, T]) AddEdgesFrom(g Graph[K, T]) error { @@ -238,10 +221,6 @@ func (u *undirected[K, T]) UpdateEdge(source, target K, options ...func(properti } func (u *undirected[K, T]) RemoveEdge(source, target K) error { - if _, err := u.Edge(source, target); err != nil { - return err - } - if err := u.store.RemoveEdge(source, target); err != nil { return fmt.Errorf("failed to remove edge from %v to %v: %w", source, target, err) } @@ -311,20 +290,13 @@ func (u *undirected[K, T]) Order() (int, error) { } func (u *undirected[K, T]) Size() (int, error) { - size := 0 - - outEdges, err := u.AdjacencyMap() + edges, err := u.store.ListEdges() if err != nil { - return 0, fmt.Errorf("failed to get adjacency map: %w", err) - } - - for _, outEdges := range outEdges { - size += len(outEdges) + return 0, fmt.Errorf("failed to list edges: %w", err) } - // Divide by 2 since every add edge operation on undirected graph is counted // twice. - return size / 2, nil + return len(edges) / 2, nil } func (u *undirected[K, T]) edgesAreEqual(a, b Edge[T]) bool { @@ -343,6 +315,17 @@ func (u *undirected[K, T]) edgesAreEqual(a, b Edge[T]) bool { return false } +func (u *undirected[K, T]) createsCycle(source, target K) (bool, error) { + // If the underlying store implements CreatesCycle, use that fast path. + if cc, ok := u.store.(interface { + CreatesCycle(source, target K) (bool, error) + }); ok { + return cc.CreatesCycle(source, target) + } + + // Slow path. + return CreatesCycle(Graph[K, T](u), source, target) +} func (u *undirected[K, T]) addEdge(sourceHash, targetHash K, edge Edge[K]) error { err := u.store.AddEdge(sourceHash, targetHash, edge) diff --git a/undirected_test.go b/undirected_test.go index c304eb3..3f09ea9 100644 --- a/undirected_test.go +++ b/undirected_test.go @@ -77,7 +77,7 @@ func TestUndirected_AddVertex(t *testing.T) { } } - if err != test.finallyExpectedError { + if !errors.Is(err, test.finallyExpectedError) { t.Errorf("%s: error expectancy doesn't match: expected %v, got %v", name, test.finallyExpectedError, err) } @@ -238,7 +238,7 @@ func TestUndirected_Vertex(t *testing.T) { vertex, err := graph.Vertex(test.vertex) - if err != test.expectedError { + if !errors.Is(err, test.expectedError) { t.Errorf("%s: error expectancy doesn't match: expected %v, got %v", name, test.expectedError, err) } @@ -880,7 +880,7 @@ func TestUndirected_RemoveEdge(t *testing.T) { removeEdges: []Edge[int]{ {Source: 2, Target: 3}, }, - expectedError: ErrEdgeNotFound, + // Expect no error because memoryStore doesn't error }, } @@ -903,7 +903,7 @@ func TestUndirected_RemoveEdge(t *testing.T) { } // After removing the edge, verify that it can't be retrieved using // Edge anymore. - if _, err := graph.Edge(removeEdge.Source, removeEdge.Target); err != ErrEdgeNotFound { + if _, err := graph.Edge(removeEdge.Source, removeEdge.Target); !errors.Is(err, ErrEdgeNotFound) { t.Fatalf("%s: error expectancy doesn't match: expected %v, got %v", name, ErrEdgeNotFound, err) } } @@ -1284,6 +1284,8 @@ func TestUndirected_addEdge(t *testing.T) { graph := newUndirected(IntHash, &Traits{}, newMemoryStore[int, int]()) for _, edge := range test.edges { + _ = graph.AddVertex(edge.Source) + _ = graph.AddVertex(edge.Target) sourceHash := graph.hash(edge.Source) TargetHash := graph.hash(edge.Target) err := graph.addEdge(sourceHash, TargetHash, edge)