|
4 | 4 | // Package graph provides functionality for directed graphs. |
5 | 5 | package graph |
6 | 6 |
|
| 7 | +import ( |
| 8 | + "context" |
| 9 | + "sync" |
| 10 | + |
| 11 | + "golang.org/x/sync/errgroup" |
| 12 | +) |
| 13 | + |
7 | 14 | // vertexStatus denotes the visiting status of a vertex when running DFS in a graph. |
8 | 15 | type vertexStatus int |
9 | 16 |
|
@@ -215,3 +222,228 @@ func TopologicalOrder[V comparable](digraph *Graph[V]) (*TopologicalSorter[V], e |
215 | 222 | topo.traverse(digraph) |
216 | 223 | return topo, nil |
217 | 224 | } |
| 225 | + |
| 226 | +// LabeledGraph extends a generic Graph by associating a label (or status) with each vertex. |
| 227 | +// It is concurrency-safe, utilizing a mutex lock for synchronized access. |
| 228 | +type LabeledGraph[V comparable] struct { |
| 229 | + *Graph[V] |
| 230 | + status map[V]string |
| 231 | + lock sync.Mutex |
| 232 | +} |
| 233 | + |
| 234 | +// NewLabeledGraph initializes a LabeledGraph with specified vertices and optional configurations. |
| 235 | +// It creates a base Graph with the vertices and applies any LabeledGraphOption to configure additional properties. |
| 236 | +func NewLabeledGraph[V comparable](vertices []V, opts ...LabeledGraphOption[V]) *LabeledGraph[V] { |
| 237 | + g := New(vertices...) |
| 238 | + lg := &LabeledGraph[V]{ |
| 239 | + Graph: g, |
| 240 | + status: make(map[V]string), |
| 241 | + } |
| 242 | + for _, opt := range opts { |
| 243 | + opt(lg) |
| 244 | + } |
| 245 | + return lg |
| 246 | +} |
| 247 | + |
| 248 | +// LabeledGraphOption allows you to initialize Graph with additional properties. |
| 249 | +type LabeledGraphOption[V comparable] func(g *LabeledGraph[V]) |
| 250 | + |
| 251 | +// WithStatus sets the status of each vertex in the Graph. |
| 252 | +func WithStatus[V comparable](status string) func(g *LabeledGraph[V]) { |
| 253 | + return func(g *LabeledGraph[V]) { |
| 254 | + g.status = make(map[V]string) |
| 255 | + for vertex := range g.vertices { |
| 256 | + g.status[vertex] = status |
| 257 | + } |
| 258 | + } |
| 259 | +} |
| 260 | + |
| 261 | +// updateStatus updates the status of a vertex. |
| 262 | +func (lg *LabeledGraph[V]) updateStatus(vertex V, status string) { |
| 263 | + lg.lock.Lock() |
| 264 | + defer lg.lock.Unlock() |
| 265 | + lg.status[vertex] = status |
| 266 | +} |
| 267 | + |
| 268 | +// getStatus gets the status of a vertex. |
| 269 | +func (lg *LabeledGraph[V]) getStatus(vertex V) string { |
| 270 | + lg.lock.Lock() |
| 271 | + defer lg.lock.Unlock() |
| 272 | + return lg.status[vertex] |
| 273 | +} |
| 274 | + |
| 275 | +// getLeaves returns the leaves of a given vertex. |
| 276 | +func (lg *LabeledGraph[V]) leaves() []V { |
| 277 | + lg.lock.Lock() |
| 278 | + defer lg.lock.Unlock() |
| 279 | + var leaves []V |
| 280 | + for vtx := range lg.vertices { |
| 281 | + if len(lg.vertices[vtx]) == 0 { |
| 282 | + leaves = append(leaves, vtx) |
| 283 | + } |
| 284 | + } |
| 285 | + return leaves |
| 286 | +} |
| 287 | + |
| 288 | +// getParents returns the parent vertices (incoming edges) of vertex. |
| 289 | +func (lg *LabeledGraph[V]) parents(vtx V) []V { |
| 290 | + lg.lock.Lock() |
| 291 | + defer lg.lock.Unlock() |
| 292 | + var parents []V |
| 293 | + for v, neighbors := range lg.vertices { |
| 294 | + if neighbors[vtx] { |
| 295 | + parents = append(parents, v) |
| 296 | + } |
| 297 | + } |
| 298 | + return parents |
| 299 | +} |
| 300 | + |
| 301 | +// getChildren returns the child vertices (outgoing edges) of vertex. |
| 302 | +func (lg *LabeledGraph[V]) children(vtx V) []V { |
| 303 | + lg.lock.Lock() |
| 304 | + defer lg.lock.Unlock() |
| 305 | + return lg.Neighbors(vtx) |
| 306 | +} |
| 307 | + |
| 308 | +// filterParents filters parents based on the vertex status. |
| 309 | +func (lg *LabeledGraph[V]) filterParents(vtx V, status string) []V { |
| 310 | + parents := lg.parents(vtx) |
| 311 | + var filtered []V |
| 312 | + for _, parent := range parents { |
| 313 | + if lg.getStatus(parent) == status { |
| 314 | + filtered = append(filtered, parent) |
| 315 | + } |
| 316 | + } |
| 317 | + return filtered |
| 318 | +} |
| 319 | + |
| 320 | +// filterChildren filters children based on the vertex status. |
| 321 | +func (lg *LabeledGraph[V]) filterChildren(vtx V, status string) []V { |
| 322 | + children := lg.children(vtx) |
| 323 | + var filtered []V |
| 324 | + for _, child := range children { |
| 325 | + if lg.getStatus(child) == status { |
| 326 | + filtered = append(filtered, child) |
| 327 | + } |
| 328 | + } |
| 329 | + return filtered |
| 330 | +} |
| 331 | + |
| 332 | +/* |
| 333 | +UpwardTraversal performs an upward traversal on the graph starting from leaves (nodes with no children) |
| 334 | +and moving towards root nodes (nodes with children). |
| 335 | +It applies the specified process function to each vertex in the graph, skipping vertices with the |
| 336 | +"adjacentVertexSkipStatus" status, and continuing traversal until reaching vertices with the "requiredVertexStatus" status. |
| 337 | +The traversal is concurrent and may process vertices in parallel. |
| 338 | +Returns an error if the traversal encounters any issues, or nil if successful. |
| 339 | +*/ |
| 340 | +func (lg *LabeledGraph[V]) UpwardTraversal(ctx context.Context, processVertexFunc func(context.Context, V) error, nextVertexSkipStatus, requiredVertexStatus string) error { |
| 341 | + traversal := &graphTraversal[V]{ |
| 342 | + mu: sync.Mutex{}, |
| 343 | + seen: make(map[V]struct{}), |
| 344 | + findStartVertices: func(lg *LabeledGraph[V]) []V { return lg.leaves() }, |
| 345 | + findNextVertices: func(lg *LabeledGraph[V], v V) []V { return lg.parents(v) }, |
| 346 | + filterPreviousVerticesByStatus: func(g *LabeledGraph[V], v V, status string) []V { return g.filterChildren(v, status) }, |
| 347 | + requiredVertexStatus: requiredVertexStatus, |
| 348 | + nextVertexSkipStatus: nextVertexSkipStatus, |
| 349 | + processVertex: processVertexFunc, |
| 350 | + } |
| 351 | + return traversal.execute(ctx, lg) |
| 352 | +} |
| 353 | + |
| 354 | +/* |
| 355 | +DownwardTraversal performs a downward traversal on the graph starting from root nodes (nodes with no parents) |
| 356 | +and moving towards leaf nodes (nodes with parents). It applies the specified process function to each |
| 357 | +vertex in the graph, skipping vertices with the "adjacentVertexSkipStatus" status, and continuing traversal |
| 358 | +until reaching vertices with the "requiredVertexStatus" status. |
| 359 | +The traversal is concurrent and may process vertices in parallel. |
| 360 | +Returns an error if the traversal encounters any issues. |
| 361 | +*/ |
| 362 | +func (lg *LabeledGraph[V]) DownwardTraversal(ctx context.Context, processVertexFunc func(context.Context, V) error, adjacentVertexSkipStatus, requiredVertexStatus string) error { |
| 363 | + traversal := &graphTraversal[V]{ |
| 364 | + mu: sync.Mutex{}, |
| 365 | + seen: make(map[V]struct{}), |
| 366 | + findStartVertices: func(lg *LabeledGraph[V]) []V { return lg.Roots() }, |
| 367 | + findNextVertices: func(lg *LabeledGraph[V], v V) []V { return lg.children(v) }, |
| 368 | + filterPreviousVerticesByStatus: func(lg *LabeledGraph[V], v V, status string) []V { return lg.filterParents(v, status) }, |
| 369 | + requiredVertexStatus: requiredVertexStatus, |
| 370 | + nextVertexSkipStatus: adjacentVertexSkipStatus, |
| 371 | + processVertex: processVertexFunc, |
| 372 | + } |
| 373 | + return traversal.execute(ctx, lg) |
| 374 | +} |
| 375 | + |
| 376 | +type graphTraversal[V comparable] struct { |
| 377 | + mu sync.Mutex |
| 378 | + seen map[V]struct{} |
| 379 | + findStartVertices func(*LabeledGraph[V]) []V |
| 380 | + findNextVertices func(*LabeledGraph[V], V) []V |
| 381 | + filterPreviousVerticesByStatus func(*LabeledGraph[V], V, string) []V |
| 382 | + requiredVertexStatus string |
| 383 | + nextVertexSkipStatus string |
| 384 | + processVertex func(context.Context, V) error |
| 385 | +} |
| 386 | + |
| 387 | +func (t *graphTraversal[V]) execute(ctx context.Context, lg *LabeledGraph[V]) error { |
| 388 | + |
| 389 | + ctx, cancel := context.WithCancel(ctx) |
| 390 | + defer cancel() |
| 391 | + |
| 392 | + vertexCount := len(lg.vertices) |
| 393 | + if vertexCount == 0 { |
| 394 | + return nil |
| 395 | + } |
| 396 | + eg, ctx := errgroup.WithContext(ctx) |
| 397 | + vertexCh := make(chan V, vertexCount) |
| 398 | + defer close(vertexCh) |
| 399 | + |
| 400 | + processVertices := func(ctx context.Context, graph *LabeledGraph[V], eg *errgroup.Group, vertices []V, vertexCh chan V) { |
| 401 | + for _, vertex := range vertices { |
| 402 | + vertex := vertex |
| 403 | + // Delay processing this vertex if any of its dependent vertices are yet to be processed. |
| 404 | + if len(t.filterPreviousVerticesByStatus(graph, vertex, t.nextVertexSkipStatus)) != 0 { |
| 405 | + continue |
| 406 | + } |
| 407 | + if !t.markAsSeen(vertex) { |
| 408 | + // Skip this vertex if it's already been processed by another routine. |
| 409 | + continue |
| 410 | + } |
| 411 | + eg.Go(func() error { |
| 412 | + if err := t.processVertex(ctx, vertex); err != nil { |
| 413 | + return err |
| 414 | + } |
| 415 | + // Assign new status to the vertex upon successful processing. |
| 416 | + graph.updateStatus(vertex, t.requiredVertexStatus) |
| 417 | + vertexCh <- vertex |
| 418 | + return nil |
| 419 | + }) |
| 420 | + } |
| 421 | + } |
| 422 | + |
| 423 | + eg.Go(func() error { |
| 424 | + for { |
| 425 | + select { |
| 426 | + case <-ctx.Done(): |
| 427 | + return ctx.Err() |
| 428 | + case vertex := <-vertexCh: |
| 429 | + vertexCount-- |
| 430 | + if vertexCount == 0 { |
| 431 | + return nil |
| 432 | + } |
| 433 | + processVertices(ctx, lg, eg, t.findNextVertices(lg, vertex), vertexCh) |
| 434 | + } |
| 435 | + } |
| 436 | + }) |
| 437 | + processVertices(ctx, lg, eg, t.findStartVertices(lg), vertexCh) |
| 438 | + return eg.Wait() |
| 439 | +} |
| 440 | + |
| 441 | +func (t *graphTraversal[V]) markAsSeen(vertex V) bool { |
| 442 | + t.mu.Lock() |
| 443 | + defer t.mu.Unlock() |
| 444 | + if _, seen := t.seen[vertex]; seen { |
| 445 | + return false |
| 446 | + } |
| 447 | + t.seen[vertex] = struct{}{} |
| 448 | + return true |
| 449 | +} |
0 commit comments