@@ -634,16 +634,33 @@ for Tri in (:UpperTriangular, :LowerTriangular)
634634end
635635
636636@inline function kron! (C:: AbstractMatrix , A:: Diagonal , B:: Diagonal )
637- valA = A. diag; nA = length (valA )
638- valB = B. diag; nB = length (valB )
637+ valA = A. diag; mA, nA = size (A )
638+ valB = B. diag; mB, nB = size (B )
639639 nC = checksquare (C)
640640 @boundscheck nC == nA* nB ||
641641 throw (DimensionMismatch (lazy " expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)" ))
642- isempty (A) || isempty (B) || fill! (C, zero (A[1 ,1 ] * B[1 ,1 ]))
642+ zerofilled = false
643+ if ! (isempty (A) || isempty (B))
644+ z = A[1 ,1 ] * B[1 ,1 ]
645+ if haszero (typeof (z))
646+ # in this case, the zero is unique
647+ fill! (C, zero (z))
648+ zerofilled = true
649+ end
650+ end
643651 @inbounds for i = 1 : nA, j = 1 : nB
644652 idx = (i- 1 )* nB+ j
645653 C[idx, idx] = valA[i] * valB[j]
646654 end
655+ if ! zerofilled
656+ for j in 1 : nA, i in 1 : mA
657+ Δrow, Δcol = (i- 1 )* mB, (j- 1 )* nB
658+ for k in 1 : nB, l in 1 : mB
659+ i == j && k == l && continue
660+ C[Δrow + l, Δcol + k] = A[i,j] * B[l,k]
661+ end
662+ end
663+ end
647664 return C
648665end
649666
670687 (mC, nC) = size (C)
671688 @boundscheck (mC, nC) == (mA * mB, nA * nB) ||
672689 throw (DimensionMismatch (lazy " expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)" ))
673- isempty (A) || isempty (B) || fill! (C, zero (A[1 ,1 ] * B[1 ,1 ]))
690+ zerofilled = false
691+ if ! (isempty (A) || isempty (B))
692+ z = A[1 ,1 ] * B[1 ,1 ]
693+ if haszero (typeof (z))
694+ # in this case, the zero is unique
695+ fill! (C, zero (z))
696+ zerofilled = true
697+ end
698+ end
674699 m = 1
675700 @inbounds for j = 1 : nA
676701 A_jj = A[j,j]
681706 end
682707 m += (nA - 1 ) * mB
683708 end
709+ if ! zerofilled
710+ # populate the zero elements
711+ for i in 1 : mA
712+ i == j && continue
713+ A_ij = A[i, j]
714+ Δrow, Δcol = (i- 1 )* mB, (j- 1 )* nB
715+ for k in 1 : nB, l in 1 : nA
716+ B_lk = B[l, k]
717+ C[Δrow + l, Δcol + k] = A_ij * B_lk
718+ end
719+ end
720+ end
684721 m += mB
685722 end
686723 return C
@@ -693,17 +730,36 @@ end
693730 (mC, nC) = size (C)
694731 @boundscheck (mC, nC) == (mA * mB, nA * nB) ||
695732 throw (DimensionMismatch (lazy " expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)" ))
696- isempty (A) || isempty (B) || fill! (C, zero (A[1 ,1 ] * B[1 ,1 ]))
733+ zerofilled = false
734+ if ! (isempty (A) || isempty (B))
735+ z = A[1 ,1 ] * B[1 ,1 ]
736+ if haszero (typeof (z))
737+ # in this case, the zero is unique
738+ fill! (C, zero (z))
739+ zerofilled = true
740+ end
741+ end
697742 m = 1
698743 @inbounds for j = 1 : nA
699744 for l = 1 : mB
700745 Bll = B[l,l]
701- for k = 1 : mA
702- C[m] = A[k ,j] * Bll
746+ for i = 1 : mA
747+ C[m] = A[i ,j] * Bll
703748 m += nB
704749 end
705750 m += 1
706751 end
752+ if ! zerofilled
753+ for i in 1 : mA
754+ A_ij = A[i, j]
755+ Δrow, Δcol = (i- 1 )* mB, (j- 1 )* nB
756+ for k in 1 : nB, l in 1 : mB
757+ l == k && continue
758+ B_lk = B[l, k]
759+ C[Δrow + l, Δcol + k] = A_ij * B_lk
760+ end
761+ end
762+ end
707763 m -= nB
708764 end
709765 return C
0 commit comments