Skip to content

Commit c14f6eb

Browse files
committed
Added the Kronecker product functionality to stdlib_linalg, and added appropriate unit tests for all supported types.
1 parent 4da9933 commit c14f6eb

File tree

3 files changed

+357
-1
lines changed

3 files changed

+357
-1
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ set(fppFiles
2222
stdlib_linalg.fypp
2323
stdlib_linalg_diag.fypp
2424
stdlib_linalg_outer_product.fypp
25+
stdlib_linalg_kronecker.fypp
2526
stdlib_linalg_cross_product.fypp
2627
stdlib_optval.fypp
2728
stdlib_selection.fypp

src/stdlib_linalg.fypp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module stdlib_linalg
1414
public :: eye
1515
public :: trace
1616
public :: outer_product
17+
public :: kronecker_product
1718
public :: cross_product
1819
public :: is_square
1920
public :: is_diagonal
@@ -93,6 +94,20 @@ module stdlib_linalg
9394
#:endfor
9495
end interface outer_product
9596

97+
interface kronecker_product
98+
!! version: experimental
99+
!!
100+
!! Computes the Kronecker product of two arrays size M1xN1, M2xN2, returning an (M1*M2)x(N1*N2) array
101+
!! ([Specification](../page/specs/stdlib_linalg.html#
102+
!! kronecker_product-computes-the-kronecker-product-of-two-matrices))
103+
#:for k1, t1 in RCI_KINDS_TYPES
104+
pure module function kronecker_product_${t1[0]}$${k1}$(A, B) result(C)
105+
${t1}$, intent(in) :: A(:,:), B(:,:)
106+
${t1}$ :: C(size(A,dim=1)*size(B,dim=1),size(A,dim=2)*size(B,dim=2))
107+
end function kronecker_product_${t1[0]}$${k1}$
108+
#:endfor
109+
end interface kronecker_product
110+
96111

97112
! Cross product (of two vectors)
98113
interface cross_product

test/linalg/test_linalg.fypp

Lines changed: 341 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module test_linalg
44
use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test
55
use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64
6-
use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product
6+
use stdlib_linalg, only: diag, eye, trace, outer_product, cross_product, kronecker_product
77

88
implicit none
99

@@ -48,6 +48,16 @@ contains
4848
new_unittest("trace_int16", test_trace_int16), &
4949
new_unittest("trace_int32", test_trace_int32), &
5050
new_unittest("trace_int64", test_trace_int64), &
51+
new_unittest("kronecker_product_rsp", test_kronecker_product_rsp), &
52+
new_unittest("kronecker_product_rdp", test_kronecker_product_rdp), &
53+
new_unittest("kronecker_product_rqp", test_kronecker_product_rqp), &
54+
new_unittest("kronecker_product_csp", test_kronecker_product_csp), &
55+
new_unittest("kronecker_product_cdp", test_kronecker_product_cdp), &
56+
new_unittest("kronecker_product_cqp", test_kronecker_product_cqp), &
57+
new_unittest("kronecker_product_int8", test_kronecker_product_iint8), &
58+
new_unittest("kronecker_product_int16", test_kronecker_product_iint16), &
59+
new_unittest("kronecker_product_int32", test_kronecker_product_iint32), &
60+
new_unittest("kronecker_product_int64", test_kronecker_product_iint64), &
5161
new_unittest("outer_product_rsp", test_outer_product_rsp), &
5262
new_unittest("outer_product_rdp", test_outer_product_rdp), &
5363
new_unittest("outer_product_rqp", test_outer_product_rqp), &
@@ -552,6 +562,336 @@ contains
552562
"trace(h) == sum(c(0:nd:2)) failed.")
553563

554564
end subroutine test_trace_int64
565+
566+
subroutine test_kronecker_product_rsp(error)
567+
!> Error handling
568+
type(error_type), allocatable, intent(out) :: error
569+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
570+
real(sp), parameter :: tol = 1.e-6
571+
572+
real(sp) :: A(m1,n1), B(m2,n2)
573+
real(sp) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
574+
575+
integer :: i,j
576+
577+
do j=1, n1
578+
do i=1, m1
579+
A(i,j) = i*j ! A = [1, 2]
580+
end do
581+
end do
582+
583+
do j=1, n2
584+
do i=1, m2
585+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
586+
end do
587+
end do
588+
589+
C = kronecker_product(A,B)
590+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
591+
diff = C - expected
592+
593+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
594+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
595+
596+
end subroutine test_kronecker_product_rsp
597+
598+
subroutine test_kronecker_product_rdp(error)
599+
!> Error handling
600+
type(error_type), allocatable, intent(out) :: error
601+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
602+
real(dp), parameter :: tol = 1.e-6
603+
604+
real(dp) :: A(m1,n1), B(m2,n2)
605+
real(dp) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
606+
607+
integer :: i,j
608+
609+
do j=1, n1
610+
do i=1, m1
611+
A(i,j) = i*j ! A = [1, 2]
612+
end do
613+
end do
614+
615+
do j=1, n2
616+
do i=1, m2
617+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
618+
end do
619+
end do
620+
621+
C = kronecker_product(A,B)
622+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
623+
diff = C - expected
624+
625+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
626+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
627+
628+
end subroutine test_kronecker_product_rdp
629+
630+
subroutine test_kronecker_product_rqp(error)
631+
!> Error handling
632+
type(error_type), allocatable, intent(out) :: error
633+
#:if WITH_QP
634+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
635+
real(qp), parameter :: tol = 1.e-6
636+
637+
real(qp) :: A(m1,n1), B(m2,n2)
638+
real(qp) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
639+
640+
integer :: i,j
641+
642+
do j=1, n1
643+
do i=1, m1
644+
A(i,j) = i*j ! A = [1, 2]
645+
end do
646+
end do
647+
648+
do j=1, n2
649+
do i=1, m2
650+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
651+
end do
652+
end do
653+
654+
C = kronecker_product(A,B)
655+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
656+
diff = C - expected
657+
658+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
659+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
660+
#:else
661+
call skip_test(error, "Quadruple precision is not enabled")
662+
#:endif
663+
664+
end subroutine test_kronecker_product_rqp
665+
666+
subroutine test_kronecker_product_csp(error)
667+
!> Error handling
668+
type(error_type), allocatable, intent(out) :: error
669+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
670+
complex(sp), parameter :: tol = 1.e-6
671+
672+
complex(sp) :: A(m1,n1), B(m2,n2)
673+
complex(sp) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
674+
675+
integer :: i,j
676+
677+
do j=1, n1
678+
do i=1, m1
679+
A(i,j) = i*j ! A = [1, 2]
680+
end do
681+
end do
682+
683+
do j=1, n2
684+
do i=1, m2
685+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
686+
end do
687+
end do
688+
689+
C = kronecker_product(A,B)
690+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
691+
diff = C - expected
692+
693+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
694+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
695+
696+
end subroutine test_kronecker_product_csp
697+
698+
subroutine test_kronecker_product_cdp(error)
699+
!> Error handling
700+
type(error_type), allocatable, intent(out) :: error
701+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
702+
complex(dp), parameter :: tol = 1.e-6
703+
704+
complex(dp) :: A(m1,n1), B(m2,n2)
705+
complex(dp) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
706+
707+
integer :: i,j
708+
709+
do j=1, n1
710+
do i=1, m1
711+
A(i,j) = i*j ! A = [1, 2]
712+
end do
713+
end do
714+
715+
do j=1, n2
716+
do i=1, m2
717+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
718+
end do
719+
end do
720+
721+
C = kronecker_product(A,B)
722+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
723+
diff = C - expected
724+
725+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
726+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
727+
728+
end subroutine test_kronecker_product_cdp
729+
730+
subroutine test_kronecker_product_cqp(error)
731+
!> Error handling
732+
type(error_type), allocatable, intent(out) :: error
733+
#:if WITH_QP
734+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
735+
complex(qp), parameter :: tol = 1.e-6
736+
737+
complex(qp) :: A(m1,n1), B(m2,n2)
738+
complex(qp) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
739+
740+
integer :: i,j
741+
742+
do j=1, n1
743+
do i=1, m1
744+
A(i,j) = i*j ! A = [1, 2]
745+
end do
746+
end do
747+
748+
do j=1, n2
749+
do i=1, m2
750+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
751+
end do
752+
end do
753+
754+
C = kronecker_product(A,B)
755+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
756+
diff = C - expected
757+
758+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
759+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
760+
#:else
761+
call skip_test(error, "Quadruple precision is not enabled")
762+
#:endif
763+
764+
end subroutine test_kronecker_product_cqp
765+
766+
subroutine test_kronecker_product_iint8(error)
767+
!> Error handling
768+
type(error_type), allocatable, intent(out) :: error
769+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
770+
integer(int8), parameter :: tol = 1.e-6
771+
772+
integer(int8) :: A(m1,n1), B(m2,n2)
773+
integer(int8) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
774+
775+
integer :: i,j
776+
777+
do j=1, n1
778+
do i=1, m1
779+
A(i,j) = i*j ! A = [1, 2]
780+
end do
781+
end do
782+
783+
do j=1, n2
784+
do i=1, m2
785+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
786+
end do
787+
end do
788+
789+
C = kronecker_product(A,B)
790+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
791+
diff = C - expected
792+
793+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
794+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
795+
796+
end subroutine test_kronecker_product_iint8
797+
798+
subroutine test_kronecker_product_iint16(error)
799+
!> Error handling
800+
type(error_type), allocatable, intent(out) :: error
801+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
802+
integer(int16), parameter :: tol = 1.e-6
803+
804+
integer(int16) :: A(m1,n1), B(m2,n2)
805+
integer(int16) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
806+
807+
integer :: i,j
808+
809+
do j=1, n1
810+
do i=1, m1
811+
A(i,j) = i*j ! A = [1, 2]
812+
end do
813+
end do
814+
815+
do j=1, n2
816+
do i=1, m2
817+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
818+
end do
819+
end do
820+
821+
C = kronecker_product(A,B)
822+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
823+
diff = C - expected
824+
825+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
826+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
827+
828+
end subroutine test_kronecker_product_iint16
829+
830+
subroutine test_kronecker_product_iint32(error)
831+
!> Error handling
832+
type(error_type), allocatable, intent(out) :: error
833+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
834+
integer(int32), parameter :: tol = 1.e-6
835+
836+
integer(int32) :: A(m1,n1), B(m2,n2)
837+
integer(int32) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
838+
839+
integer :: i,j
840+
841+
do j=1, n1
842+
do i=1, m1
843+
A(i,j) = i*j ! A = [1, 2]
844+
end do
845+
end do
846+
847+
do j=1, n2
848+
do i=1, m2
849+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
850+
end do
851+
end do
852+
853+
C = kronecker_product(A,B)
854+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
855+
diff = C - expected
856+
857+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
858+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
859+
860+
end subroutine test_kronecker_product_iint32
861+
862+
subroutine test_kronecker_product_iint64(error)
863+
!> Error handling
864+
type(error_type), allocatable, intent(out) :: error
865+
integer, parameter :: m1=1, n1=2, m2=2, n2=3
866+
integer(int64), parameter :: tol = 1.e-6
867+
868+
integer(int64) :: A(m1,n1), B(m2,n2)
869+
integer(int64) :: C(m1*m2,n1*n2), expected(m1*m2,n1*n2), diff(m1*m2,n1*n2)
870+
871+
integer :: i,j
872+
873+
do j=1, n1
874+
do i=1, m1
875+
A(i,j) = i*j ! A = [1, 2]
876+
end do
877+
end do
878+
879+
do j=1, n2
880+
do i=1, m2
881+
B(i,j) = i*j ! B = [[1, 2, 3], [2, 4, 6]]
882+
end do
883+
end do
884+
885+
C = kronecker_product(A,B)
886+
887+
expected = transpose(reshape([1,2,3, 2,4,6, 2,4,6, 4, 8, 12], [m2*n2, m1*n1]))
888+
889+
diff = C - expected
890+
891+
call check(error, all(abs(diff) .le. abs(tol)), "all(abs(diff) .le. abs(tol)) failed")
892+
! Expected: C = [1*B, 2*B] = [[1,2,3, 2,4,6], [2,4,6, 4, 8, 12]]
893+
894+
end subroutine test_kronecker_product_iint64
555895

556896

557897
subroutine test_outer_product_rsp(error)

0 commit comments

Comments
 (0)