Skip to content

Commit 946cfb4

Browse files
committed
add option to pin node positions to spring
1 parent 1c7f6d2 commit 946cfb4

File tree

2 files changed

+34
-25
lines changed

2 files changed

+34
-25
lines changed

src/spring.jl

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,52 +23,59 @@ the nodes.
2323
- `initialtemp=2.0`: Initial "temperature", controls movement per iteration
2424
- `initialpos=Point{dim,Ptype}[]`
2525
26-
Provide list of initial positions. If length does not match Network size the initial
27-
positions will be truncated or filled up with random values between [-1,1] in every coordinate.
26+
Provide `Vector` or `Dict` of initial positions. All positions will be initialized
27+
using random coordinates between [-1,1]. Random positions will be overwritten using
28+
the key-val-pairs provided by this argument.
29+
30+
- `pin=[]`: Pin node positions (won't be updated). Can be given as `Vector` or `Dict`
31+
of node index -> value pairings. Values can be either
32+
- `(12, 4.0)` : overwrite initial position and pin
33+
- `true/false` : pin this position
34+
- `(true, false, false)` : only pin certain coordinates
2835
2936
- `seed=1`: Seed for random initial positions.
3037
"""
3138
@addcall struct Spring{Dim,Ptype} <: IterativeLayout{Dim,Ptype}
3239
C::Float64
3340
iterations::Int
3441
initialtemp::Float64
35-
initialpos::Vector{Point{Dim,Ptype}}
42+
initialpos::Dict{Int,Point{Dim,Ptype}}
43+
pin::Dict{Int,SVector{Dim,Bool}}
3644
seed::UInt
3745
end
3846

39-
function Spring(; dim=2, Ptype=Float64, C=2.0, iterations=100, initialtemp=2.0, initialpos=Point{dim,Ptype}[],
47+
function Spring(; dim=2, Ptype=Float64,
48+
C=2.0, iterations=100, initialtemp=2.0,
49+
initialpos=[], pin=[],
4050
seed=1)
4151
if !isempty(initialpos)
42-
initialpos = Point.(initialpos)
43-
Ptype = eltype(eltype(initialpos))
44-
# TODO fix initial pos if list has points of multiple types
45-
Ptype == Any && error("Please provide list of Point{N,T} with same T")
46-
dim = length(eltype(initialpos))
52+
dim, Ptype = infer_pointtype(initialpos)
53+
Ptype = promote_type(Float32, Ptype) # make sure to get at least f32 if given as int
4754
end
48-
return Spring{dim,Ptype}(C, iterations, initialtemp, initialpos, seed)
55+
_initialpos, _pin = _sanitize_initialpos_pin(dim, Ptype, initialpos, pin)
56+
57+
return Spring{dim,Ptype}(C, iterations, initialtemp, _initialpos, _pin, seed)
4958
end
5059

5160
function Base.iterate(iter::LayoutIterator{<:Spring{Dim,Ptype}}) where {Dim,Ptype}
5261
algo, adj_matrix = iter.algorithm, iter.adj_matrix
5362
N = size(adj_matrix, 1)
54-
M = length(algo.initialpos)
5563
rng = MersenneTwister(algo.seed)
56-
startpos = Vector{Point{Dim,Ptype}}(undef, N)
57-
# take the first
58-
for i in 1:min(N, M)
59-
startpos[i] = algo.initialpos[i]
60-
end
61-
# fill the rest with random points
62-
for i in (M + 1):N
63-
startpos[i] = 2 .* rand(rng, Point{Dim,Ptype}) .- 1
64+
startpos = [2 .* rand(rng, Point{Dim,Ptype}) .- 1 for _ in 1:N]
65+
66+
for (k, v) in algo.initialpos
67+
startpos[k] = v
6468
end
65-
# iteratorstate: #iter nr, old pos
66-
return (startpos, (1, startpos))
69+
70+
pin = [get(algo.pin, i, SVector{Dim,Bool}(false for _ in 1:Dim)) for i in 1:N]
71+
72+
# iteratorstate: #iter nr, old pos, pin
73+
return (startpos, (1, startpos, pin))
6774
end
6875

6976
function Base.iterate(iter::LayoutIterator{<:Spring}, state)
7077
algo, adj_matrix = iter.algorithm, iter.adj_matrix
71-
iteration, old_pos = state
78+
iteration, old_pos, pin = state
7279
iteration >= algo.iterations && return nothing
7380

7481
# The optimal distance bewteen vertices
@@ -108,8 +115,10 @@ function Base.iterate(iter::LayoutIterator{<:Spring}, state)
108115
force_mag = norm(force[i])
109116
iszero(force_mag) && continue
110117
scale = min(force_mag, temp) ./ force_mag
111-
locs[i] += force[i] .* scale
118+
119+
mask = (!).(pin[i]) # where pin=true mask will multiply with 0
120+
locs[i] += force[i] .* scale .* mask
112121
end
113122

114-
return locs, (iteration + 1, locs)
123+
return locs, (iteration + 1, locs, pin)
115124
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ jagmesh_adj = jagmesh()
390390
end
391391

392392
@testset "test pin" begin
393-
for algo in [sfdp]
393+
for algo in [sfdp, spring]
394394
g = complete_graph(10)
395395
ep = algo(g; pin=[(0,0), (0,0)])
396396
@test ep[1] == [0,0]

0 commit comments

Comments
 (0)