Skip to content

Commit 3b959ef

Browse files
Add multivariate distribution
1 parent 90314c0 commit 3b959ef

File tree

4 files changed

+277
-0
lines changed

4 files changed

+277
-0
lines changed

src/fstats.f90

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ module fstats
2525
public :: f_distribution
2626
public :: chi_squared_distribution
2727
public :: binomial_distribution
28+
public :: multivariate_distribution
29+
public :: multivariate_distribution_function
30+
public :: multivariate_normal_distribution
2831
public :: mean
2932
public :: variance
3033
public :: standard_deviation

src/fstats_distributions.f90

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ module fstats_distributions
33
use ieee_arithmetic
44
use fstats_special_functions
55
use fstats_helper_routines
6+
use ferror
7+
use fstats_errors
68
implicit none
79
private
810
public :: distribution
@@ -13,6 +15,9 @@ module fstats_distributions
1315
public :: f_distribution
1416
public :: chi_squared_distribution
1517
public :: binomial_distribution
18+
public :: multivariate_distribution
19+
public :: multivariate_distribution_function
20+
public :: multivariate_normal_distribution
1621

1722
real(real64), parameter :: pi = 2.0d0 * acos(0.0d0)
1823

@@ -137,6 +142,48 @@ pure function distribution_property(this) result(rst)
137142
procedure, public :: variance => bd_variance
138143
end type
139144

145+
! ******************************************************************************
146+
! MULTIVARIATE DISTRIBUTIONS
147+
! ------------------------------------------------------------------------------
148+
type, abstract :: multivariate_distribution
    !! Abstract parent type for multivariate probability distributions.
    !! Every extension must supply an implementation of the deferred
    !! binding `pdf`.
contains
    procedure(multivariate_distribution_function), pass, deferred :: pdf
        !! Evaluates the probability density function at a point.
end type multivariate_distribution
154+
155+
interface
156+
pure function multivariate_distribution_function(this, x) result(rst)
157+
!! Defines an interface for a multivariate probability distribution
158+
!! function.
159+
use iso_fortran_env, only : real64
160+
import multivariate_distribution
161+
class(multivariate_distribution), intent(in) :: this
162+
!! The distribution object.
163+
real(real64), intent(in), dimension(:) :: x
164+
!! The values at which to evaluate the function.
165+
real(real64) :: rst
166+
!! The value of the function.
167+
end function
168+
end interface
169+
170+
! ------------------------------------------------------------------------------
171+
type, extends(multivariate_distribution) :: multivariate_normal_distribution
    !! A multivariate normal (Gaussian) distribution described by a vector
    !! of means and a covariance matrix.
    real(real64), allocatable, private :: m_means(:)
        !! The N mean values.
    real(real64), allocatable, private :: m_cov(:,:)
        !! The N-by-N covariance matrix, which must be positive-definite.
    real(real64), allocatable, private :: m_covInv(:,:)
        !! The N-by-N inverse of the covariance matrix.
    real(real64), private :: m_covDet
        !! The determinant of the covariance matrix.
contains
    procedure, public :: pdf => mvnd_pdf
    procedure, public :: initialize => mvnd_init
end type multivariate_normal_distribution
186+
140187
contains
141188
! ------------------------------------------------------------------------------
142189
pure elemental function dist_std_var(this, x) result(rst)
@@ -658,5 +705,189 @@ pure function bd_variance(this) result(rst)
658705
rst = this%n * this%p * (1.0d0 - this%p)
659706
end function
660707

708+
! ******************************************************************************
709+
! MULTIVARIATE NORMAL DISTRIBUTION
710+
! ------------------------------------------------------------------------------
711+
subroutine mvnd_init(this, mu, sigma, err)
    !! Initializes the multivariate normal distribution by defining the mean
    !! values and covariance matrix.
    !!
    !! The covariance matrix is Cholesky-factored; its inverse and
    !! determinant are derived from that factor and cached for use by the
    !! PDF.  All work is performed on local storage and the object is only
    !! updated once every step has succeeded, so a failed call (size
    !! mismatch, allocation failure, or an error reported by the
    !! factorization) leaves `this` exactly as it was on entry.
    use linalg, only : cholesky_factor
    class(multivariate_normal_distribution), intent(inout) :: this
        !! The multivariate_normal_distribution object.
    real(real64), intent(in), dimension(:) :: mu
        !! An N-element array containing the mean values.
    real(real64), intent(in), dimension(:,:) :: sigma
        !! The N-by-N covariance matrix.  The PDF exists only if this matrix
        !! is positive-definite; the matrix is handed to the Cholesky
        !! factorization routine, and any error that routine reports through
        !! the error handling object aborts the initialization.
    class(errors), intent(inout), optional, target :: err
        !! The error handling object.  If not supplied, a local default
        !! error object is used.

    ! Local Variables
    integer(int32) :: n, flag
    real(real64), allocatable, dimension(:,:) :: L, covInv
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Initialization
    if (present(err)) then
        errmgr => err
    else
        errmgr => deferr
    end if
    n = size(mu)

    ! Input Checking
    if (size(sigma, 1) /= n .or. size(sigma, 2) /= n) then
        call report_matrix_size_error(errmgr, "mvnd_init", "sigma", n, n, &
            size(sigma, 1), size(sigma, 2))
        return
    end if

    ! Local workspace: L starts as a copy of sigma and is factored in place;
    ! covInv receives the inverse.
    allocate(L(n, n), stat = flag, source = sigma)
    if (flag == 0) allocate(covInv(n, n), stat = flag)
    if (flag /= 0) then
        call report_memory_error(errmgr, "mvnd_init", flag)
        return
    end if

    ! Compute the Cholesky factorization of the covariance matrix
    call cholesky_factor(L, upper = .false., err = errmgr)
    if (errmgr%has_error_occurred()) return

    ! Compute the inverse from the factor
    call populate_identity(covInv)
    call cholesky_inverse(L, covInv)

    ! Everything succeeded - commit the results to the object.  Assignment
    ! to the allocatable components (re)allocates them as required, and
    ! move_alloc hands over the inverse without another copy.
    this%m_means = mu
    this%m_cov = sigma
    call move_alloc(covInv, this%m_covInv)
    this%m_covDet = cholesky_determinant(L)
end subroutine
781+
782+
! ------------------------------------------------------------------------------
783+
pure function mvnd_pdf(this, x) result(rst)
    !! Evaluates the PDF for the multivariate normal distribution.
    !!
    !! $$ f(x) = \frac{\exp\left(-\frac{1}{2} (x - \mu)^T \Sigma^{-1}
    !! (x - \mu)\right)}{\sqrt{(2 \pi)^n |\Sigma|}} $$
    !!
    !! A quiet NaN is returned if the object has not been initialized or if
    !! the number of values in x does not match the number of stored means;
    !! being a pure function, no error can be reported in those cases.
    class(multivariate_normal_distribution), intent(in) :: this
        !! The multivariate_normal_distribution object.
    real(real64), intent(in), dimension(:) :: x
        !! The values at which to evaluate the function.
    real(real64) :: rst
        !! The value of the function.

    ! Local Variables
    integer(int32) :: n
    real(real64) :: arg
    real(real64), allocatable, dimension(:) :: delta, prod

    ! Guard against use before initialization and non-conforming input.
    ! The allocation test is done first as size() of an unallocated array
    ! is not permitted.
    rst = ieee_value(0.0d0, ieee_quiet_nan)
    if (.not.allocated(this%m_means) .or. .not.allocated(this%m_covInv)) return
    n = size(x)
    if (size(this%m_means) /= n) return

    ! Process
    delta = x - this%m_means
    prod = matmul(this%m_covInv, delta)     ! prod = inv(sigma) * (x - mu)
    arg = dot_product(delta, prod)          ! arg = (x - mu)**T * prod
    rst = exp(-0.5d0 * arg) / sqrt((2.0d0 * pi)**n * this%m_covDet)
end function
804+
805+
! ******************************************************************************
806+
! SUPPORTING ROUTINES
807+
! ------------------------------------------------------------------------------
808+
subroutine cholesky_inverse(x, u)
    !! Computes the inverse of a matrix from its lower-triangular Cholesky
    !! factor by means of two triangular solves.
    use linalg, only : solve_triangular_system
    real(real64), intent(in), dimension(:,:) :: x
        !! The lower-triangular Cholesky factored matrix.
    real(real64), intent(inout), dimension(:,:) :: u
        !! On input, an N-by-N identity matrix.  On output, the N-by-N
        !! inverted matrix.

    ! The logical flags are handed positionally to solve_triangular_system.
    ! NOTE(review): the names below reflect the presumed meaning of the
    ! linalg arguments (left-side, upper, transpose, non-unit diagonal) -
    ! confirm against the linalg documentation.
    logical, parameter :: left_side = .true.
    logical, parameter :: upper_tri = .false.
    logical, parameter :: plain = .false.
    logical, parameter :: transposed = .true.
    logical, parameter :: non_unit = .true.

    ! With A = L * L**T, the inverse satisfies (L * L**T) * inv(A) = I.
    ! Split this into two steps that each involve only the factor:
    !
    !   1.  L * U = I              -> U
    !   2.  L**T * inv(A) = U      -> inv(A)
    !
    ! Both solves overwrite u in place.

    ! Step 1: solve L * U = I for U
    call solve_triangular_system(left_side, upper_tri, plain, non_unit, &
        1.0d0, x, u)

    ! Step 2: solve L**T * inv(A) = U for inv(A)
    call solve_triangular_system(left_side, upper_tri, transposed, non_unit, &
        1.0d0, x, u)
end subroutine
834+
835+
! ------------------------------------------------------------------------------
836+
pure function cholesky_determinant(x) result(rst)
    !! Computes the determinant of a matrix from its lower-triangular
    !! Cholesky factor; i.e. the product of the squares of the diagonal
    !! entries of the factor.
    !!
    !! The running product is carried as a binary mantissa and an integer
    !! exponent so that intermediate results can neither overflow nor
    !! underflow; only the final result is subject to the range of the
    !! floating-point type.  Scaling by powers of two is exact, so no
    !! rounding error is introduced by the rescaling itself.
    real(real64), intent(in), dimension(:,:) :: x
        !! The lower-triangular Cholesky-factored matrix.
    real(real64) :: rst
        !! The determinant.  Zero is returned as soon as a zero diagonal
        !! entry is encountered.  If a non-finite (Inf or NaN) diagonal
        !! entry is encountered, its square is returned.

    ! Local Variables
    integer(int32) :: i, n, ex
    real(real64) :: d, mant

    ! Initialization - the product is mant * 2**ex
    n = min(size(x, 1), size(x, 2))
    mant = 1.0d0
    ex = 0

    ! Accumulate the product of the squares of the diagonal
    do i = 1, n
        d = x(i,i)

        ! fraction/exponent are not meaningful for Inf or NaN; propagate
        ! the non-finite value instead of attempting to rescale it.
        if (.not.ieee_is_finite(d)) then
            rst = d * d
            return
        end if

        ! An exact zero on the diagonal means a zero determinant.
        if (d == 0.0d0) then
            rst = 0.0d0
            return
        end if

        ! d**2 = fraction(d)**2 * 2**(2 * exponent(d)); the squared fraction
        ! lies in [0.25, 1) so the mantissa product cannot leave the
        ! representable range.
        mant = mant * fraction(d)**2
        ex = ex + 2 * exponent(d) + exponent(mant)

        ! Renormalize the mantissa to [0.5, 1)
        mant = fraction(mant)
    end do

    ! Recombine; overflows to Inf or underflows to zero only if the
    ! determinant itself is out of range.
    rst = scale(mant, ex)
end function
873+
874+
! ------------------------------------------------------------------------------
875+
subroutine populate_identity(x)
    !! Overwrites the supplied matrix with zeros everywhere except for ones
    !! along its main diagonal.  The matrix need not be square.
    real(real64), intent(inout), dimension(:,:) :: x
        !! The matrix to overwrite.

    ! Local Variables
    integer(int32) :: k

    ! Clear the matrix, then walk the main diagonal, whose length is the
    ! smaller of the two extents.
    x = 0.0d0
    do k = 1, min(size(x, 1), size(x, 2))
        x(k,k) = 1.0d0
    end do
end subroutine
891+
661892
! ------------------------------------------------------------------------------
662893
end module

tests/fstats_distribution_tests.f90

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,5 +219,45 @@ function test_standardized_variable() result(rst)
219219
end if
220220
end function
221221

222+
! ------------------------------------------------------------------------------
223+
function test_multivariate_normal_distribution() result(rst)
    !! Checks the PDF of a 2-variable multivariate_normal_distribution
    !! against a reference value computed directly from the inverse and
    !! determinant of the covariance matrix.
    use linalg, only : mtx_inverse, det
    ! Arguments
    logical :: rst

    ! Parameters
    real(real64), parameter :: pi = 2.0d0 * acos(0.0d0)
    real(real64), parameter :: tol = 1.0d-8

    ! Local Variables
    real(real64) :: x(2), mu(2), rho, s1, s2, sigma(2, 2), arg, ans, phi, &
        dsig, inv(2, 2)
    type(multivariate_normal_distribution) :: dist

    ! Initialization
    rst = .true.
    call random_number(x)
    call random_number(mu)
    call random_number(rho)
    call random_number(s1)
    call random_number(s2)

    ! random_number returns values in [0, 1).  Used directly, the standard
    ! deviations may be (nearly) zero, which makes the covariance matrix
    ! singular or badly conditioned and the comparison below unreliable.
    ! Keep the standard deviations in [0.5, 1.5) and the correlation in
    ! [0, 0.9) so the matrix is always comfortably positive-definite.
    rho = 0.9d0 * rho
    s1 = 0.5d0 + s1
    s2 = 0.5d0 + s2
    sigma = reshape([s2**2, -rho * s1 * s2, -rho * s1 * s2, s1**2], [2, 2])
    call dist%initialize(mu, sigma)

    ! Compute the actual solution
    inv = sigma
    call mtx_inverse(inv)
    arg = -0.5d0 * dot_product(x - mu, matmul(inv, x - mu))
    dsig = det(sigma)
    ans = exp(arg) / sqrt((2.0d0 * pi)**2 * dsig)

    ! Test
    phi = dist%pdf(x)
    if (.not.is_equal(phi, ans, tol)) then
        rst = .false.
        print "(A)", "TEST FAILED: test_multivariate_normal_distribution -1"
    end if
end function
261+
222262
! ------------------------------------------------------------------------------
223263
end module

tests/fstats_tests.f90

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ program tests
3030
local = binomial_distribution_test_1()
3131
if (.not.local) overall = .false.
3232

33+
local = test_multivariate_normal_distribution()
34+
if (.not.local) overall = .false.
35+
3336
! Statistics Tests
3437
local = mean_test_1()
3538
if (.not.local) overall = .false.

0 commit comments

Comments
 (0)