Skip to content

Commit e6a647f

Browse files
committed
DGraph: Support add_partition of sub-graph
1 parent d67e025 commit e6a647f

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

lib/DaggerGraphs/src/dgraph.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,17 +213,15 @@ function Graphs.add_vertices!(g::DGraphState, n::Integer)
213213
end
214214
function add_partition!(g::DGraph, n::Integer)
215215
check_not_frozen(g)
216-
with_state(g, add_partition!, n)
216+
return with_state(g, add_partition!, n)
217217
end
218218
function add_partition!(g::DGraphState{T,D}, n::Integer) where {T,D}
219219
check_not_frozen(g)
220220
if n < 1
221221
throw(ArgumentError("n must be >= 1"))
222222
end
223223
push!(g.parts, Dagger.spawn(n) do n
224-
g = D ? SimpleDiGraph() : SimpleGraph()
225-
add_vertices!(g, n)
226-
g
224+
D ? SimpleDiGraph(n) : SimpleGraph(n)
227225
end)
228226
num_v = nv(g)
229227
push!(g.parts_nv, (num_v+1):(num_v+n))
@@ -233,6 +231,18 @@ function add_partition!(g::DGraphState{T,D}, n::Integer) where {T,D}
233231
push!(g.bg_adjs_ne_src, 0)
234232
return length(g.parts)
235233
end
234+
function add_partition!(g::DGraph, sg::AbstractGraph)
235+
check_not_frozen(g)
236+
return with_state(g, add_partition!, sg)
237+
end
238+
function add_partition!(g::DGraphState{T,D}, sg::AbstractGraph) where {T,D}
239+
check_not_frozen(g)
240+
shift = nv(g)
241+
part = add_partition!(g, nv(sg))
242+
part_edges = map(edge->(src(edge)+shift, dst(edge)+shift), collect(edges(sg)))
243+
@assert add_edges!(g, part_edges)
244+
return part
245+
end
236246
function Graphs.add_edge!(g::DGraph, src::Integer, dst::Integer)
237247
check_not_frozen(g)
238248
return with_state(g, add_edge!, src, dst)
@@ -382,3 +392,10 @@ function Graphs.outneighbors(g::DGraphState, v::Integer)
382392
return neighbors
383393
end
384394
Graphs.weights(g::DGraph) = Graphs.DefaultDistance(nv(g))
395+
396+
get_partition(g::DGraph, part::Integer) =
397+
with_state(g, get_partition, part)
398+
get_partition(g::DGraphState, part::Integer) = fetch(g.parts[part])
399+
get_background(g::DGraph, part::Integer) =
400+
with_state(g, get_background, part)
401+
get_background(g::DGraphState, part::Integer) = fetch(g.bg_adjs[part])

0 commit comments

Comments
 (0)