Skip to content

Commit b2414e0

Browse files
authored
Merge branch 'master' into offset1offsetarrays
2 parents 88be350 + d668a41 commit b2414e0

File tree

5 files changed

+325
-313
lines changed

5 files changed

+325
-313
lines changed

src/ArrayInterface.jl

Lines changed: 3 additions & 234 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ else
2525
end
2626
end
2727

28+
const CanonicalInt = Union{Int,StaticInt}
29+
2830
if VERSION v"1.6.0-DEV.1581"
2931
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A} = true
3032
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A} = false
@@ -48,6 +50,7 @@ const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}
4850
@inline static_step(x) = Static.maybe_static(known_step, step, x)
4951

5052
include("ndindex.jl")
53+
include("array_index.jl")
5154

5255
"""
5356
parent_type(::Type{T})
@@ -264,240 +267,6 @@ function findstructralnz(x::Diagonal)
264267
(1:n, 1:n)
265268
end
266269

267-
abstract type MatrixIndex end
268-
269-
struct BidiagonalIndex <: MatrixIndex
270-
count::Int
271-
isup::Bool
272-
end
273-
274-
struct TridiagonalIndex <: MatrixIndex
275-
count::Int# count==nsize+nsize-1+nsize-1
276-
nsize::Int
277-
isrow::Bool
278-
end
279-
280-
struct BandedMatrixIndex <: MatrixIndex
281-
count::Int
282-
rowsize::Int
283-
colsize::Int
284-
bandinds::Array{Int,1}
285-
bandsizes::Array{Int,1}
286-
isrow::Bool
287-
end
288-
289-
function _bandsize(bandind, rowsize, colsize)
290-
-(rowsize - 1) <= bandind <= colsize - 1 || throw(ErrorException("Invalid Bandind"))
291-
if (bandind * (colsize - rowsize) > 0) & (abs(bandind) <= abs(colsize - rowsize))
292-
return min(rowsize, colsize)
293-
elseif bandind * (colsize - rowsize) <= 0
294-
return min(rowsize, colsize) - abs(bandind)
295-
else
296-
return min(rowsize, colsize) - abs(bandind) + abs(colsize - rowsize)
297-
end
298-
end
299-
300-
function BandedMatrixIndex(rowsize, colsize, lowerbandwidth, upperbandwidth, isrow)
301-
upperbandwidth > -lowerbandwidth || throw(ErrorException("Invalid Bandwidths"))
302-
bandinds = upperbandwidth:-1:-lowerbandwidth
303-
bandsizes = [_bandsize(band, rowsize, colsize) for band in bandinds]
304-
BandedMatrixIndex(sum(bandsizes), rowsize, colsize, bandinds, bandsizes, isrow)
305-
end
306-
307-
struct BlockBandedMatrixIndex <: MatrixIndex
308-
count::Int
309-
refinds::Array{Int,1}
310-
refcoords::Array{Int,1}# storing col or row inds at ref points
311-
isrow::Bool
312-
end
313-
314-
function BlockBandedMatrixIndex(nrowblock, ncolblock, rowsizes, colsizes, l, u)
315-
blockrowind = BandedMatrixIndex(nrowblock, ncolblock, l, u, true)
316-
blockcolind = BandedMatrixIndex(nrowblock, ncolblock, l, u, false)
317-
sortedinds = sort(
318-
[(blockrowind[i], blockcolind[i]) for i = 1:length(blockrowind)],
319-
by = x -> x[1],
320-
)
321-
sort!(sortedinds, by = x -> x[2], alg = InsertionSort)# stable sort keeps the second index in order
322-
refinds = Array{Int,1}()
323-
refrowcoords = Array{Int,1}()
324-
refcolcoords = Array{Int,1}()
325-
rowheights = cumsum(pushfirst!(copy(rowsizes), 1))
326-
blockheight = 0
327-
blockrow = 1
328-
blockcol = 1
329-
currenti = 1
330-
lastrowind = sortedinds[1][1] - 1
331-
lastcolind = sortedinds[1][2]
332-
for ind in sortedinds
333-
rowind, colind = ind
334-
if colind == lastcolind
335-
if rowind > lastrowind
336-
blockheight += rowsizes[rowind]
337-
end
338-
else
339-
for j = blockcol:blockcol+colsizes[lastcolind]-1
340-
push!(refinds, currenti)
341-
push!(refrowcoords, blockrow)
342-
push!(refcolcoords, j)
343-
currenti += blockheight
344-
end
345-
blockcol += colsizes[lastcolind]
346-
blockrow = rowheights[rowind]
347-
blockheight = rowsizes[rowind]
348-
end
349-
lastcolind = colind
350-
lastrowind = rowind
351-
end
352-
for j = blockcol:blockcol+colsizes[lastcolind]-1
353-
push!(refinds, currenti)
354-
push!(refrowcoords, blockrow)
355-
push!(refcolcoords, j)
356-
currenti += blockheight
357-
end
358-
push!(refinds, currenti)# guard
359-
push!(refrowcoords, -1)
360-
push!(refcolcoords, -1)
361-
rowindobj = BlockBandedMatrixIndex(currenti - 1, refinds, refrowcoords, true)
362-
colindobj = BlockBandedMatrixIndex(currenti - 1, refinds, refcolcoords, false)
363-
rowindobj, colindobj
364-
end
365-
366-
struct BandedBlockBandedMatrixIndex <: MatrixIndex
367-
count::Int
368-
refinds::Array{Int,1}
369-
refcoords::Array{Int,1}# storing col or row inds at ref points
370-
reflocalinds::Array{BandedMatrixIndex,1}
371-
isrow::Bool
372-
end
373-
374-
function BandedBlockBandedMatrixIndex(
375-
nrowblock,
376-
ncolblock,
377-
rowsizes,
378-
colsizes,
379-
l,
380-
u,
381-
lambda,
382-
mu,
383-
)
384-
blockrowind = BandedMatrixIndex(nrowblock, ncolblock, l, u, true)
385-
blockcolind = BandedMatrixIndex(nrowblock, ncolblock, l, u, false)
386-
sortedinds = sort(
387-
[(blockrowind[i], blockcolind[i]) for i = 1:length(blockrowind)],
388-
by = x -> x[1],
389-
)
390-
sort!(sortedinds, by = x -> x[2], alg = InsertionSort)# stable sort keeps the second index in order
391-
rowheights = cumsum(pushfirst!(copy(rowsizes), 1))
392-
colwidths = cumsum(pushfirst!(copy(colsizes), 1))
393-
currenti = 1
394-
refinds = Array{Int,1}()
395-
refrowcoords = Array{Int,1}()
396-
refcolcoords = Array{Int,1}()
397-
reflocalrowinds = Array{BandedMatrixIndex,1}()
398-
reflocalcolinds = Array{BandedMatrixIndex,1}()
399-
for ind in sortedinds
400-
rowind, colind = ind
401-
localrowind =
402-
BandedMatrixIndex(rowsizes[rowind], colsizes[colind], lambda, mu, true)
403-
localcolind =
404-
BandedMatrixIndex(rowsizes[rowind], colsizes[colind], lambda, mu, false)
405-
push!(refinds, currenti)
406-
push!(refrowcoords, rowheights[rowind])
407-
push!(refcolcoords, colwidths[colind])
408-
push!(reflocalrowinds, localrowind)
409-
push!(reflocalcolinds, localcolind)
410-
currenti += localrowind.count
411-
end
412-
push!(refinds, currenti)
413-
push!(refrowcoords, -1)
414-
push!(refcolcoords, -1)
415-
rowindobj = BandedBlockBandedMatrixIndex(
416-
currenti - 1,
417-
refinds,
418-
refrowcoords,
419-
reflocalrowinds,
420-
true,
421-
)
422-
colindobj = BandedBlockBandedMatrixIndex(
423-
currenti - 1,
424-
refinds,
425-
refcolcoords,
426-
reflocalcolinds,
427-
false,
428-
)
429-
rowindobj, colindobj
430-
end
431-
432-
Base.firstindex(ind::MatrixIndex) = 1
433-
Base.lastindex(ind::MatrixIndex) = ind.count
434-
Base.length(ind::MatrixIndex) = ind.count
435-
function Base.getindex(ind::BidiagonalIndex, i::Int)
436-
1 <= i <= ind.count || throw(BoundsError(ind, i))
437-
if ind.isup
438-
ii = i + 1
439-
else
440-
ii = i + 1 + 1
441-
end
442-
convert(Int, floor(ii / 2))
443-
end
444-
445-
function Base.getindex(ind::TridiagonalIndex, i::Int)
446-
1 <= i <= ind.count || throw(BoundsError(ind, i))
447-
offsetu = ind.isrow ? 0 : 1
448-
offsetl = ind.isrow ? 1 : 0
449-
if 1 <= i <= ind.nsize
450-
return i
451-
elseif ind.nsize < i <= ind.nsize + ind.nsize - 1
452-
return i - ind.nsize + offsetu
453-
else
454-
return i - (ind.nsize + ind.nsize - 1) + offsetl
455-
end
456-
end
457-
458-
function Base.getindex(ind::BandedMatrixIndex, i::Int)
459-
1 <= i <= ind.count || throw(BoundsError(ind, i))
460-
_i = i
461-
p = 1
462-
while _i - ind.bandsizes[p] > 0
463-
_i -= ind.bandsizes[p]
464-
p += 1
465-
end
466-
bandind = ind.bandinds[p]
467-
startfromone = !xor(ind.isrow, (bandind > 0))
468-
if startfromone
469-
return _i
470-
else
471-
return _i + abs(bandind)
472-
end
473-
end
474-
475-
function Base.getindex(ind::BlockBandedMatrixIndex, i::Int)
476-
1 <= i <= ind.count || throw(BoundsError(ind, i))
477-
p = 1
478-
while i - ind.refinds[p] >= 0
479-
p += 1
480-
end
481-
p -= 1
482-
_i = i - ind.refinds[p]
483-
if ind.isrow
484-
return ind.refcoords[p] + _i
485-
else
486-
return ind.refcoords[p]
487-
end
488-
end
489-
490-
function Base.getindex(ind::BandedBlockBandedMatrixIndex, i::Int)
491-
1 <= i <= ind.count || throw(BoundsError(ind, i))
492-
p = 1
493-
while i - ind.refinds[p] >= 0
494-
p += 1
495-
end
496-
p -= 1
497-
_i = i - ind.refinds[p] + 1
498-
ind.reflocalinds[p][_i] + ind.refcoords[p] - 1
499-
end
500-
501270
function findstructralnz(x::Bidiagonal)
502271
n = Base.size(x, 1)
503272
isup = x.uplo == 'U' ? true : false

0 commit comments

Comments
 (0)