diff --git a/directed.go b/directed.go index 95d67f5..0997d5e 100644 --- a/directed.go +++ b/directed.go @@ -131,25 +131,15 @@ func (d *directed[K, T]) AddEdgesFrom(g Graph[K, T]) error { return nil } -func (d *directed[K, T]) Edge(sourceHash, targetHash K) (Edge[T], error) { +func (d *directed[K, T]) Edge(sourceHash, targetHash K) (Edge[K], error) { edge, err := d.store.Edge(sourceHash, targetHash) if err != nil { - return Edge[T]{}, err + return Edge[K]{}, err } - sourceVertex, _, err := d.store.Vertex(sourceHash) - if err != nil { - return Edge[T]{}, err - } - - targetVertex, _, err := d.store.Vertex(targetHash) - if err != nil { - return Edge[T]{}, err - } - - return Edge[T]{ - Source: sourceVertex, - Target: targetVertex, + return Edge[K]{ + Source: sourceHash, + Target: targetHash, Properties: EdgeProperties{ Weight: edge.Properties.Weight, Attributes: edge.Properties.Attributes, @@ -286,13 +276,10 @@ func (d *directed[K, T]) Size() (int, error) { return size, nil } -func (d *directed[K, T]) edgesAreEqual(a, b Edge[T]) bool { - aSourceHash := d.hash(a.Source) - aTargetHash := d.hash(a.Target) - bSourceHash := d.hash(b.Source) - bTargetHash := d.hash(b.Target) - - return aSourceHash == bSourceHash && aTargetHash == bTargetHash +// This only tells you the source and target are the same, it does not +// tell you that the edge properties are the same as well. +func (d *directed[K, T]) edgesAreEqual(a, b Edge[K]) bool { + return a.Source == b.Source && a.Target == b.Target } func (d *directed[K, T]) createsCycle(source, target K) (bool, error) { diff --git a/graph.go b/graph.go index 9376eb5..e731294 100644 --- a/graph.go +++ b/graph.go @@ -124,7 +124,7 @@ type Graph[K comparable, T any] interface { // Edge returns the edge joining two given vertices or ErrEdgeNotFound if // the edge doesn't exist. In an undirected graph, an edge with swapped // source and target vertices does match. - Edge(sourceHash, targetHash K) (Edge[T], error) + Edge(sourceHash, targetHash K) (Edge[K], error) // Edges returns a slice of all edges in the graph. These edges are of type // Edge[K] and hence will contain the vertex hashes, not the vertex values. @@ -213,9 +213,9 @@ type Graph[K comparable, T any] interface { // Edge represents an edge that joins two vertices. Even though these edges are // always referred to as source and target, whether the graph is directed or not // is determined by its traits. -type Edge[T any] struct { - Source T - Target T +type Edge[K comparable] struct { + Source K + Target K Properties EdgeProperties } @@ -232,6 +232,21 @@ type EdgeProperties struct { Data any } +func (p *EdgeProperties) Clone() EdgeProperties { + + ep := EdgeProperties{ + Attributes: make(map[string]string), + Weight: p.Weight, + Data: p.Data, + } + + for k, v := range p.Attributes { + ep.Attributes[k] = v + } + + return ep +} + // Hash is a hashing function that takes a vertex of type T and returns a hash // value of type K. // diff --git a/store.go b/store.go index 30dc0eb..61814f5 100644 --- a/store.go +++ b/store.go @@ -71,16 +71,16 @@ type memoryStore[K comparable, T any] struct { // outEdges and inEdges store all outgoing and ingoing edges for all vertices. For O(1) access, // these edges themselves are stored in maps whose keys are the hashes of the target vertices. - outEdges map[K]map[K]Edge[K] // source -> target - inEdges map[K]map[K]Edge[K] // target -> source + outEdges map[K]map[K]*EdgeProperties // source -> target + inEdges map[K]map[K]*EdgeProperties // target -> source } func newMemoryStore[K comparable, T any]() Store[K, T] { return &memoryStore[K, T]{ vertices: make(map[K]T), vertexProperties: make(map[K]VertexProperties), - outEdges: make(map[K]map[K]Edge[K]), - inEdges: make(map[K]map[K]Edge[K]), + outEdges: make(map[K]map[K]*EdgeProperties), + inEdges: make(map[K]map[K]*EdgeProperties), } } @@ -143,16 +143,16 @@ func (s *memoryStore[K, T]) RemoveVertex(k K) error { if len(edges) > 0 { return ErrVertexHasEdges } - delete(s.inEdges, k) } if edges, ok := s.outEdges[k]; ok { if len(edges) > 0 { return ErrVertexHasEdges } - delete(s.outEdges, k) } + delete(s.inEdges, k) + delete(s.outEdges, k) delete(s.vertices, k) delete(s.vertexProperties, k) @@ -163,31 +163,42 @@ func (s *memoryStore[K, T]) AddEdge(sourceHash, targetHash K, edge Edge[K]) erro s.lock.Lock() defer s.lock.Unlock() + edgeProperties := edge.Properties.Clone() + if _, ok := s.outEdges[sourceHash]; !ok { - s.outEdges[sourceHash] = make(map[K]Edge[K]) + s.outEdges[sourceHash] = make(map[K]*EdgeProperties) } - s.outEdges[sourceHash][targetHash] = edge + s.outEdges[sourceHash][targetHash] = &edgeProperties if _, ok := s.inEdges[targetHash]; !ok { - s.inEdges[targetHash] = make(map[K]Edge[K]) + s.inEdges[targetHash] = make(map[K]*EdgeProperties) } - s.inEdges[targetHash][sourceHash] = edge + s.inEdges[targetHash][sourceHash] = &edgeProperties return nil } func (s *memoryStore[K, T]) UpdateEdge(sourceHash, targetHash K, edge Edge[K]) error { - if _, err := s.Edge(sourceHash, targetHash); err != nil { - return err - } + + edgeProperties := edge.Properties.Clone() s.lock.Lock() defer s.lock.Unlock() - s.outEdges[sourceHash][targetHash] = edge - s.inEdges[targetHash][sourceHash] = edge + sourceEdges, ok := s.outEdges[sourceHash] + if !ok { + return ErrEdgeNotFound + } + + _, ok = sourceEdges[targetHash] + if !ok { + return ErrEdgeNotFound + } + + s.outEdges[sourceHash][targetHash] = &edgeProperties + s.inEdges[targetHash][sourceHash] = &edgeProperties return nil } @@ -198,6 +209,7 @@ func (s *memoryStore[K, T]) RemoveEdge(sourceHash, targetHash K) error { delete(s.inEdges[targetHash], sourceHash) delete(s.outEdges[sourceHash], targetHash) + return nil } @@ -210,11 +222,17 @@ func (s *memoryStore[K, T]) Edge(sourceHash, targetHash K) (Edge[K], error) { return Edge[K]{}, ErrEdgeNotFound } - edge, ok := sourceEdges[targetHash] + edgeProperties, ok := sourceEdges[targetHash] if !ok { return Edge[K]{}, ErrEdgeNotFound } + edge := Edge[K]{ + Source: sourceHash, + Target: targetHash, + Properties: edgeProperties.Clone(), + } + return edge, nil } @@ -223,8 +241,15 @@ func (s *memoryStore[K, T]) ListEdges() ([]Edge[K], error) { defer s.lock.RUnlock() res := make([]Edge[K], 0) - for _, edges := range s.outEdges { - for _, edge := range edges { + for sourceKey, edges := range s.outEdges { + for targetKey, edgeProperties := range edges { + + edge := Edge[K]{ + Source: sourceKey, + Target: targetKey, + Properties: edgeProperties.Clone(), + } + res = append(res, edge) } } diff --git a/undirected.go b/undirected.go index 37d320c..6533e3d 100644 --- a/undirected.go +++ b/undirected.go @@ -135,32 +135,19 @@ func (u *undirected[K, T]) AddVerticesFrom(g Graph[K, T]) error { return nil } -func (u *undirected[K, T]) Edge(sourceHash, targetHash K) (Edge[T], error) { - // In an undirected graph, since multigraphs aren't supported, the edge AB - // is the same as BA. Therefore, if source[target] cannot be found, this - // function also looks for target[source]. - edge, err := u.store.Edge(sourceHash, targetHash) - if errors.Is(err, ErrEdgeNotFound) { - edge, err = u.store.Edge(targetHash, sourceHash) - } - - if err != nil { - return Edge[T]{}, err - } +func (u *undirected[K, T]) Edge(sourceHash, targetHash K) (Edge[K], error) { - sourceVertex, _, err := u.store.Vertex(sourceHash) - if err != nil { - return Edge[T]{}, err - } + // no need to do a reverse lookup because addEdge() is already writing it + // in both directions to the store. + edge, err := u.store.Edge(sourceHash, targetHash) - targetVertex, _, err := u.store.Vertex(targetHash) if err != nil { - return Edge[T]{}, err + return Edge[K]{}, err } - return Edge[T]{ - Source: sourceVertex, - Target: targetVertex, + return Edge[K]{ + Source: sourceHash, + Target: targetHash, Properties: EdgeProperties{ Weight: edge.Properties.Weight, Attributes: edge.Properties.Attributes, @@ -327,18 +314,14 @@ func (u *undirected[K, T]) Size() (int, error) { return size / 2, nil } -func (u *undirected[K, T]) edgesAreEqual(a, b Edge[T]) bool { - aSourceHash := u.hash(a.Source) - aTargetHash := u.hash(a.Target) - bSourceHash := u.hash(b.Source) - bTargetHash := u.hash(b.Target) +func (u *undirected[K, T]) edgesAreEqual(a, b Edge[K]) bool { - if aSourceHash == bSourceHash && aTargetHash == bTargetHash { + if a.Source == b.Source && a.Target == b.Target { return true } if !u.traits.IsDirected { - return aSourceHash == bTargetHash && aTargetHash == bSourceHash + return a.Source == b.Target && a.Target == b.Source } return false diff --git a/undirected_test.go b/undirected_test.go index c304eb3..dfa8fb3 100644 --- a/undirected_test.go +++ b/undirected_test.go @@ -384,11 +384,8 @@ func TestUndirected_AddEdge(t *testing.T) { } for _, expectedEdge := range test.expectedEdges { - sourceHash := graph.hash(expectedEdge.Source) - targetHash := graph.hash(expectedEdge.Target) - - edge, ok := graph.store.(*memoryStore[int, int]).outEdges[sourceHash][targetHash] - if !ok { + edge, err := graph.Edge(expectedEdge.Source, expectedEdge.Target) + if err != nil { t.Fatalf("%s: edge with source %v and target %v not found", name, expectedEdge.Source, expectedEdge.Target) }