-
Notifications
You must be signed in to change notification settings - Fork 221
Add stdlib_spatial module with Kabsch–Umeyama vector alignment algorithm #1119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
ea6b295
655af91
60818bb
400b85d
5208695
25df8d2
33859b6
c7b3227
78f35a4
a9c489c
fb08933
88e76fb
e73a93f
e3a1cc2
27b68ca
c895ed5
d722366
b9b9832
37b3cba
723661b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| ADD_EXAMPLE(kabsch_umeyama) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| program example_kabsch_umeyama | ||
| use stdlib_linalg_constants, only: dp | ||
| use stdlib_spatial, only: kabsch_umeyama | ||
| implicit none | ||
|
|
||
| integer, parameter :: d = 2, N = 3 | ||
| real(dp) :: P(d, N), Q(d, N), R(d, d), t(d), c, rmsd | ||
|
|
||
| integer :: i | ||
|
|
||
| P(:,1) = [3.0_dp, -2.0_dp] | ||
| P(:,2) = [7.0_dp, 4.0_dp] | ||
| P(:,3) = [5.0_dp, 0.0_dp] | ||
|
|
||
| Q(:,1) = [2.0_dp, 3.0_dp] | ||
| Q(:,2) = [-1.0_dp, 5.0_dp] | ||
| Q(:,3) = [1.0_dp, 4.0_dp] | ||
|
|
||
| call kabsch_umeyama(P, Q, R, t, c, rmsd) | ||
|
|
||
| print *, "" | ||
| print *, "Recovered rotation R:" | ||
| do i = 1, d | ||
| print *, R(i,:) | ||
| end do | ||
|
|
||
| print *, "" | ||
| print *, "Recovered scale c:", c | ||
|
|
||
| print *, "" | ||
| print *, "Recovered translation t:" | ||
| print *, t | ||
|
|
||
| print *, "" | ||
| print *, "RMSD:", rmsd | ||
|
|
||
| end program example_kabsch_umeyama |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| set(spatial_fppFiles | ||
| stdlib_spatial.fypp | ||
| stdlib_spatial_kabsch_umeyama.fypp | ||
| ) | ||
|
|
||
| set(spatial_cppFiles | ||
| ) | ||
|
|
||
| set(spatial_f90Files | ||
| ) | ||
|
|
||
| configure_stdlib_target(${PROJECT_NAME}_spatial spatial_f90Files spatial_fppFiles spatial_cppFiles) | ||
|
|
||
| target_link_libraries(${PROJECT_NAME}_spatial PUBLIC ${PROJECT_NAME}_constants ${PROJECT_NAME}_linalg_core ${PROJECT_NAME}_linalg ${PROJECT_NAME}_intrinsics) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| #:include "common.fypp" | ||
| #:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX)) | ||
| #:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX)) | ||
| #:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES | ||
| module stdlib_spatial | ||
| use stdlib_linalg_constants | ||
| use stdlib_constants | ||
| use stdlib_error, only: error_stop | ||
| implicit none | ||
| private | ||
| public :: kabsch_umeyama | ||
|
|
||
| interface kabsch_umeyama | ||
| !----------------------------------------------------------------------- | ||
| !> Compute the optimal similarity transform (Kabsch–Umeyama): | ||
| !> | ||
| !> P ≈ c * R * Q + t | ||
| !> | ||
| !> where: | ||
| !> - R is an orthogonal rotation matrix | ||
| !> - c is an optional scale factor | ||
| !> - t is a translation vector | ||
| !> | ||
| !> The transformation minimizes the RMSD between corresponding columns | ||
| !> of P and Q, optionally using weights. | ||
| !----------------------------------------------------------------------- | ||
| #:for k, t, s in (KINDS_TYPES) | ||
| module subroutine kabsch_umeyama_${s}$(P, Q, R, t, c, rmsd, W, scale) | ||
| !> Reference point set (d × N) | ||
| ${t}$, intent(in) :: P(:, :) | ||
| !> Target point set (d × N) | ||
| ${t}$, intent(in) :: Q(:, :) | ||
| !> Optimal rotation matrix (d × d) | ||
| ${t}$, intent(out) :: R(:,:) | ||
| !> Translation vector (d) | ||
| ${t}$, intent(out) :: t(:) | ||
| !> Scale factor | ||
| ${t}$, intent(out) :: c | ||
| !> Root-mean-square deviation | ||
| real(${k}$), intent(out) :: rmsd | ||
| !> Optional weights | ||
| real(${k}$), intent(in), optional :: W(:) | ||
| !> Enable scaling | ||
| logical, intent(in), optional :: scale | ||
| end subroutine | ||
| #:endfor | ||
| end interface | ||
| end module stdlib_spatial |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| #:include "common.fypp" | ||
| #:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX)) | ||
| #:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX)) | ||
| #:set KINDS_TYPES = R_KINDS_TYPES+C_KINDS_TYPES | ||
| submodule(stdlib_spatial) stdlib_spatial_kabsch_umeyama | ||
| use stdlib_linalg, only: svd | ||
| use stdlib_intrinsics, only: stdlib_sum_kahan, stdlib_dot_product_kahan, kahan_kernel | ||
|
|
||
| contains | ||
| #:for k, t, s in (KINDS_TYPES) | ||
| module subroutine kabsch_umeyama_${s}$(P, Q, R, t, c, rmsd, W, scale) | ||
| !> Reference point set (d × N) | ||
| ${t}$, intent(in) :: P(:, :) | ||
| !> Target point set (d × N) | ||
| ${t}$, intent(in) :: Q(:, :) | ||
| !> Optimal rotation matrix (d × d) | ||
| ${t}$, intent(out) :: R(:,:) | ||
| !> Translation vector (d) | ||
| ${t}$, intent(out) :: t(:) | ||
| !> Scale factor | ||
| ${t}$, intent(out) :: c | ||
| !> Root-mean-square deviation | ||
| real(${k}$), intent(out) :: rmsd | ||
| !> Optional weights | ||
| real(${k}$), intent(in), optional :: W(:) | ||
| !> Enable scaling | ||
| logical, intent(in), optional :: scale | ||
|
|
||
|
Comment on lines
23
to
28
|
||
| ! Internal variables. | ||
| integer(ilp) :: i, j, point, d, N | ||
| ${t}$, allocatable :: covariance(:,:), U(:,:), Vt(:,:), vec(:), tmp_N(:), tmp_d(:), c_P(:), c_Q(:) | ||
| real(${k}$) :: sum_w, variance_p | ||
| real(${k}$), allocatable :: S(:) | ||
| logical :: scale_ | ||
| real(${k}$) :: rmsd_err | ||
|
|
||
|
|
||
| ! Dimension checks | ||
| if(size(P,dim=1)/=size(Q,dim=1) .or. size(P,dim=1)/=size(R,dim=1) .or. size(P,dim=1)/=size(R,dim=2) & | ||
| .or. size(P,dim=1)/=size(t)) then | ||
| call error_stop("array sizes do not match") | ||
| end if | ||
| if(size(P,dim=2)/=size(Q,dim=2)) then | ||
| call error_stop("array sizes do not match") | ||
| end if | ||
| if (present(W)) then | ||
| if (size(W) /= size(P,dim=2)) then | ||
| call error_stop("array sizes do not match") | ||
| end if | ||
| end if | ||
| d = size(P,dim=1) | ||
| N = size(P,dim=2) | ||
| scale_ = .true. | ||
| if(present(scale)) scale_ = scale | ||
|
|
||
| sum_w = one_${k}$ / N | ||
| if(present(W)) sum_w = one_${k}$ / stdlib_sum_kahan(W) | ||
|
|
||
| allocate(c_P(d), source=zero_${s}$) | ||
| allocate(c_Q(d), source=zero_${s}$) | ||
| allocate(tmp_N(N), source=zero_${s}$) | ||
|
|
||
| ! Compute centroids of P and Q | ||
| if(present(W)) then | ||
| do i = 1, d | ||
| tmp_N(:) = W(:) * P(i,:) | ||
| c_P(i) = stdlib_sum_kahan(tmp_N) | ||
| tmp_N(:) = W(:) * Q(i,:) | ||
| c_Q(i) = stdlib_sum_kahan(tmp_N) | ||
| end do | ||
| else | ||
| do i = 1, d | ||
| c_P(i) = stdlib_sum_kahan(P(i, :)) | ||
| c_Q(i) = stdlib_sum_kahan(Q(i, :)) | ||
| end do | ||
| end if | ||
| c_P = c_P * sum_w | ||
| c_Q = c_Q * sum_w | ||
|
|
||
| ! Compute covariance matrix H = (P - c_P) * (Q - c_Q)^T and variance of P | ||
| allocate(covariance(d,d), source=zero_${s}$) | ||
| allocate(tmp_d(d), source=zero_${s}$) | ||
| variance_p = zero_${k}$ | ||
|
|
||
| if (present(W)) then | ||
| do point = 1, N | ||
| tmp_d = P(:, point) - c_P(:) | ||
| tmp_N(point) = stdlib_dot_product_kahan(tmp_d, tmp_d) | ||
| end do | ||
| tmp_N(:) = W(:) * tmp_N(:) | ||
| variance_p = stdlib_sum_kahan(tmp_N) | ||
| do j = 1, d | ||
| do i = 1, d | ||
| #:if t.startswith('complex') | ||
| tmp_N(:) = W(:) * (P(i,:) - c_P(i)) * conjg(Q(j,:) - c_Q(j)) | ||
| covariance(i,j) = stdlib_sum_kahan(tmp_N) | ||
| #:else | ||
| tmp_N(:) = W(:) * (P(i,:) - c_P(i)) * (Q(j,:) - c_Q(j)) | ||
| covariance(i,j) = stdlib_sum_kahan(tmp_N) | ||
| #:endif | ||
| end do | ||
| end do | ||
| else | ||
| do point = 1, N | ||
| tmp_d = P(:, point) - c_P(:) | ||
| tmp_N(point) = stdlib_dot_product_kahan(tmp_d, tmp_d) | ||
| end do | ||
| variance_p = stdlib_sum_kahan(tmp_N) | ||
| do j = 1, d | ||
| do i = 1, d | ||
| #:if t.startswith('complex') | ||
| tmp_N(:) = (P(i,:) - c_P(i)) * conjg(Q(j,:) - c_Q(j)) | ||
| covariance(i,j) = stdlib_sum_kahan(tmp_N) | ||
| #:else | ||
| tmp_N(:) = (P(i,:) - c_P(i)) * (Q(j,:) - c_Q(j)) | ||
| covariance(i,j) = stdlib_sum_kahan(tmp_N) | ||
| #:endif | ||
| end do | ||
| end do | ||
| end if | ||
|
|
||
| covariance = covariance * sum_w | ||
| variance_p = variance_p * sum_w | ||
|
|
||
| allocate(U(d,d), source=zero_${s}$) | ||
| allocate(Vt(d,d), source=zero_${s}$) | ||
| allocate(S(d), source=zero_${k}$) | ||
|
|
||
| ! SVD of covariance matrix H -> H = U * S * Vt | ||
| call svd(covariance, S, U, Vt) | ||
|
|
||
| ! Optimal rotation matrix. | ||
| do i = 1,d | ||
| do j = 1,d | ||
| #:if t.startswith('complex') | ||
| R(i,j) = stdlib_dot_product_kahan(conjg(U(i,:)), Vt(:, j)) | ||
| #:else | ||
| R(i,j) = stdlib_dot_product_kahan(U(i,:), Vt(:, j)) | ||
| #:endif | ||
| end do | ||
| end do | ||
|
Comment on lines
+129
to
+141
|
||
|
|
||
| ! Scaling factor | ||
| if(scale_) then | ||
| c = variance_p / (sum(S(1:d))) | ||
| else | ||
| c = one_${s}$ | ||
| end if | ||
|
|
||
| ! Translation vector t = c_P - c*R*c_Q | ||
| do i = 1, d | ||
| #:if t.startswith('complex') | ||
| t(i) = c_P(i) - c * stdlib_dot_product_kahan(conjg(R(i,1:d)), c_Q(1:d)) | ||
| #:else | ||
| t(i) = c_P(i) - c * stdlib_dot_product_kahan(R(i,1:d), c_Q(1:d)) | ||
| #:endif | ||
| end do | ||
|
|
||
| ! Compute RMSD | ||
| allocate(vec(d), source=zero_${s}$) | ||
| rmsd = zero_${k}$ | ||
| rmsd_err = zero_${k}$ | ||
| do point = 1, N | ||
| ! Calculate the k^th difference vector by the formula vec_k = c*R*Q_k + t - P_k | ||
| do i = 1, d | ||
| #:if t.startswith('complex') | ||
| vec(i) = c * stdlib_dot_product_kahan(conjg(R(i,1:d)), Q(1:d,point)) | ||
| #:else | ||
| vec(i) = c * stdlib_dot_product_kahan(R(i,1:d), Q(1:d,point)) | ||
| #:endif | ||
| end do | ||
| vec(1:d) = vec(1:d) + t(1:d) - P(1:d,point) | ||
| call kahan_kernel(real(stdlib_dot_product_kahan(vec,vec), kind=${k}$), rmsd, rmsd_err) | ||
| end do | ||
| rmsd = sqrt(rmsd * sum_w) | ||
| end subroutine | ||
| #:endfor | ||
| end submodule stdlib_spatial_kabsch_umeyama | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| set( | ||
| fppFiles | ||
| "test_spatial_kabsch_umeyama.fypp" | ||
| ) | ||
| fypp_f90pp("${fyppFlags}" "${fppFiles}" outFiles) | ||
|
|
||
| ADDTESTPP(spatial_kabsch_umeyama) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stdlib_spatialusesstdlib_error(error_stop) but the spatial target links only constants/linalg/intrinsics. To avoid undefined references when linking${PROJECT_NAME}_spatialdirectly, add${PROJECT_NAME}_core(or whichever target providesstdlib_error) totarget_link_librarieshere.