@@ -7,7 +7,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
7
7
copy, vec, setindex!, count, == , reshape, _throw_dmrs, map, zero
8
8
9
9
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
10
- norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero
10
+ norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular
11
11
12
12
import Base. Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
13
13
272
272
@inline RectDiagonal {T} (A:: V , args... ) where {T,V} = RectDiagonal {T,V} (A, args... )
273
273
@inline RectDiagonal (A:: V , args... ) where {V} = RectDiagonal {eltype(V),V} (A, args... )
274
274
275
+
276
+ # patch missing overload from Base
277
+ axes (rd:: Diagonal{<:Any,<:AbstractFill} ) = (axes (rd. diag,1 ),axes (rd. diag,1 ))
278
+ axes (T:: AbstractTriangular{<:Any,<:AbstractFill} ) = axes (parent (T))
279
+
275
280
axes (rd:: RectDiagonal ) = rd. axes
276
281
size (rd:: RectDiagonal ) = length .(rd. axes)
277
282
@@ -302,15 +307,23 @@ for f in (:triu, :triu!, :tril, :tril!)
302
307
end
303
308
304
309
310
+ Base. replace_in_print_matrix (A:: RectDiagonal , i:: Integer , j:: Integer , s:: AbstractString ) =
311
+ i == j ? s : Base. replace_with_centered_mark (s)
312
+
313
+
305
314
const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}}
306
315
const SquareEye{T,Axes} = Diagonal{T,Ones{T,1 ,Tuple{Axes}}}
307
316
const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1 ,Tuple{Axes}}}
308
317
309
318
@inline SquareEye {T} (n:: Integer ) where T = Diagonal (Ones {T} (n))
310
319
@inline SquareEye (n:: Integer ) = Diagonal (Ones (n))
320
+ @inline SquareEye {T} (ax:: Tuple{AbstractUnitRange{Int}} ) where T = Diagonal (Ones {T} (ax))
321
+ @inline SquareEye (ax:: Tuple{AbstractUnitRange{Int}} ) = Diagonal (Ones (ax))
311
322
312
- @inline Eye {T} (n:: Integer ) where T = Diagonal (Ones {T} (n))
313
- @inline Eye (n:: Integer ) = Diagonal (Ones (n))
323
+ @inline Eye {T} (n:: Integer ) where T = SquareEye {T} (n)
324
+ @inline Eye (n:: Integer ) = SquareEye (n)
325
+ @inline Eye {T} (ax:: Tuple{AbstractUnitRange{Int}} ) where T = SquareEye {T} (ax)
326
+ @inline Eye (ax:: Tuple{AbstractUnitRange{Int}} ) = SquareEye (ax)
314
327
315
328
# function iterate(iter::Eye, istate = (1, 1))
316
329
# (i::Int, j::Int) = istate
328
341
329
342
Eye (n:: Integer , m:: Integer ) = RectDiagonal (Ones (min (n,m)), n, m)
330
343
Eye {T} (n:: Integer , m:: Integer ) where T = RectDiagonal {T} (Ones {T} (min (n,m)), n, m)
344
+ function Eye {T} ((a,b):: NTuple{2,AbstractUnitRange{Int}} ) where T
345
+ ab = length (a) ≤ length (b) ? a : b
346
+ RectDiagonal {T} (Ones {T} ((ab,)), (a,b))
347
+ end
348
+ function Eye ((a,b):: NTuple{2,AbstractUnitRange{Int}} )
349
+ ab = length (a) ≤ length (b) ? a : b
350
+ RectDiagonal (Ones ((ab,)), (a,b))
351
+ end
352
+
353
+
331
354
@deprecate Eye {T} (sz:: Tuple{Vararg{Integer,2}} ) where T Eye {T} (sz... )
332
355
@deprecate Eye (sz:: Tuple{Vararg{Integer,2}} ) Eye {Float64} (sz... )
333
356
357
+
358
+
334
359
@inline Eye {T} (A:: AbstractMatrix ) where T = Eye {T} (size (A)... )
335
360
@inline Eye (A:: AbstractMatrix ) = Eye {eltype(A)} (size (A)... )
336
361
@@ -506,5 +531,20 @@ include("fillbroadcast.jl")
506
531
Base. replace_in_print_matrix (:: Zeros , :: Integer , :: Integer , s:: AbstractString ) =
507
532
Base. replace_with_centered_mark (s)
508
533
534
+ # following support blocked fill array printing via
535
+ # BlockArrays.jl
536
+ axes_print_matrix_row (_, io, X, A, i, cols, sep) =
537
+ Base. invoke (Base. print_matrix_row, Tuple{IO,AbstractVecOrMat,Vector,Integer,AbstractVector,AbstractString},
538
+ io, X, A, i, cols, sep)
539
+
540
+ Base. print_matrix_row (io:: IO ,
541
+ X:: Union {AbstractFill{<: Any ,1 },
542
+ AbstractFill{<: Any ,2 },
543
+ Diagonal{<: Any ,<: AbstractFill{<:Any,1} },
544
+ RectDiagonal,
545
+ AbstractTriangular{<: Any ,<: AbstractFill{<:Any,2} }
546
+ }, A:: Vector ,
547
+ i:: Integer , cols:: AbstractVector , sep:: AbstractString ) =
548
+ axes_print_matrix_row (axes (X), io, X, A, i, cols, sep)
509
549
510
550
end # module
0 commit comments