Skip to content

Commit cc0badf

Browse files
committed
add option to pin node positions to stress
1 parent 946cfb4 commit cc0badf

File tree

2 files changed

+44
-23
lines changed

2 files changed

+44
-23
lines changed

src/stress.jl

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,16 @@ representation of a network and returns coordinates of the nodes.
3939
4040
- `initialpos=Point{dim,Ptype}[]`
4141
42-
Provide list of initial positions. If length does not match Network size the initial
43-
positions will be truncated or filled up with random normal distributed values in every coordinate.
42+
Provide `Vector` or `Dict` of initial positions. All positions will be
43+
initialized using random coordinates from normal distribution. Random
44+
positions will be overwritten using the key-val-pairs provided by this
45+
argument.
46+
47+
- `pin=[]`: Pin node positions (won't be updated). Can be given as `Vector` or `Dict`
48+
of node index -> value pairings. Values can be either
49+
- `(12, 4.0)` : overwrite initial position and pin
50+
- `true/false` : pin this position
51+
- `(true, false, false)` : only pin certain coordinates
4452
4553
- `seed=1`: Seed for random initial positions.
4654
@@ -68,7 +76,8 @@ The main equation to solve is (8) of:
6876
reltols::FT
6977
abstolx::FT
7078
weights::M
71-
initialpos::Vector{Point{Dim,Ptype}}
79+
initialpos::Dict{Int,Point{Dim,Ptype}}
80+
pin::Dict{Int,SVector{Dim,Bool}}
7281
seed::UInt
7382
end
7483

@@ -79,32 +88,36 @@ function Stress(; dim=2,
7988
reltols=10e-6,
8089
abstolx=10e-6,
8190
weights=Array{Float64}(undef, 0, 0),
82-
initialpos=Point{dim,Ptype}[],
91+
initialpos=[], pin=[],
8392
seed=1)
8493
if !isempty(initialpos)
85-
initialpos = Point.(initialpos)
86-
Ptype = eltype(eltype(initialpos))
87-
# TODO fix initial pos if list has points of multiple types
88-
Ptype == Any && error("Please provide list of Point{N,T} with same T")
89-
dim = length(eltype(initialpos))
94+
dim, Ptype = infer_pointtype(initialpos)
95+
Ptype = promote_type(Float32, Ptype) # make sure to get at least f32 if given as int
9096
end
97+
98+
_initialpos, _pin = _sanitize_initialpos_pin(dim, Ptype, initialpos, pin)
99+
91100
IT, FT, WT = typeof(iterations), typeof(abstols), typeof(weights)
92-
Stress{dim,Ptype,IT,FT,WT}(iterations, abstols, reltols, abstolx, weights, initialpos, seed)
101+
Stress{dim,Ptype,IT,FT,WT}(iterations, abstols, reltols, abstolx, weights, _initialpos, _pin, seed)
93102
end
94103

95104
function Base.iterate(iter::LayoutIterator{<:Stress{Dim,Ptype,IT,FT}}) where {Dim,Ptype,IT,FT}
96105
algo, δ = iter.algorithm, iter.adj_matrix
97106
N = size(δ, 1)
98107
M = length(algo.initialpos)
99108
rng = MersenneTwister(algo.seed)
100-
startpos = Vector{Point{Dim,Ptype}}(undef, N)
101-
# take the first
102-
for i in 1:min(N, M)
103-
startpos[i] = algo.initialpos[i]
109+
startpos = randn(rng, Point{Dim,Ptype}, N)
110+
111+
for (k, v) in algo.initialpos
112+
startpos[k] = v
104113
end
105-
# fill the rest with random points
106-
for i in (M + 1):N
107-
startpos[i] = randn(rng, Point{Dim,Ptype})
114+
115+
if isempty(algo.pin)
116+
pin = nothing
117+
else
118+
isbitstype(Ptype) || error("Pin position only available for isbitstype (got $Ptype)!")
119+
pin = [get(algo.pin, i, SVector{Dim,Bool}(false for _ in 1:Dim)) for i in 1:N]
120+
pin = reinterpret(reshape, Bool, pin)
108121
end
109122

110123
# calculate iteration if :auto
@@ -122,22 +135,30 @@ function Base.iterate(iter::LayoutIterator{<:Stress{Dim,Ptype,IT,FT}}) where {Di
122135
pinvLw = pinv(Lw)
123136
oldstress = stress(startpos, distances, weights)
124137

125-
# the `state` of the iterator is (#iter, old stress, old pos, weights, distances pinvLw, stopflag)
126-
return startpos, (1, oldstress, startpos, weights, distances, pinvLw, maxiter, false)
138+
# the `state` of the iterator is (#iter, old stress, old pos, weights, distances pinvLw, pin, stopflag)
139+
return startpos, (1, oldstress, startpos, weights, distances, pinvLw, maxiter, pin, false)
127140
end
128141

129142
function Base.iterate(iter::LayoutIterator{<:Stress{Dim,Ptype}}, state) where {Dim,Ptype}
130143
algo, δ = iter.algorithm, iter.adj_matrix
131-
i, oldstress, oldpos, weights, distances, pinvLw, maxiter, stopflag = state
144+
i, oldstress, oldpos, weights, distances, pinvLw, maxiter, pin, stopflag = state
132145

133146
if i >= maxiter || stopflag
134147
return nothing
135148
end
136149

137150
# TODO the faster way is to drop the first row and col from the iteration
138151
t = LZ(oldpos, distances, weights)
139-
positions = similar(oldpos) # allocate new array but keep type of oldpos
152+
positions = similar(oldpos)
140153
mul!(positions, pinvLw, (t * oldpos))
154+
155+
if !isnothing(pin)
156+
# on pin positions multiply newpos with zero and add oldpos
157+
_pos = reinterpret(reshape, Ptype, positions)
158+
_oldpos = reinterpret(reshape, Ptype, oldpos)
159+
_pos .= ((!).(pin) .* _pos) + (pin .* _oldpos)
160+
end
161+
141162
@assert all(x -> all(map(isfinite, x)), positions)
142163
newstress = stress(positions, distances, weights)
143164

@@ -147,7 +168,7 @@ function Base.iterate(iter::LayoutIterator{<:Stress{Dim,Ptype}}, state) where {D
147168
stopflag = true
148169
end
149170

150-
return positions, (i + 1, newstress, positions, weights, distances, pinvLw, maxiter, stopflag)
171+
return positions, (i + 1, newstress, positions, weights, distances, pinvLw, maxiter, pin, stopflag)
151172
end
152173

153174
"""

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ jagmesh_adj = jagmesh()
390390
end
391391

392392
@testset "test pin" begin
393-
for algo in [sfdp, spring]
393+
for algo in [sfdp, spring, stress]
394394
g = complete_graph(10)
395395
ep = algo(g; pin=[(0,0), (0,0)])
396396
@test ep[1] == [0,0]

0 commit comments

Comments
 (0)