1+ matrixSqrt = function(
2+ Matrix[Double] X,
3+ boolean iterMethod
4+ )return(
5+ Matrix[Double] sqrt_x
6+ ){
7+ N = nrow(X);
8+ D = ncol(X);
9+
10+ if( D == N ){
11+ # Any non singualar square matrix has a square root
12+ isDiag = isDiagonal(X)
13+ if(isDiag) {
14+ print("diag solution")
15+ sqrt_x = sqrtDiagMatrix(X);
16+ } else {
17+ if(!iterMethod) {
18+ #todo: check if all EigVal positive than possible
19+ sqrt_x = sqrtEigInv(X);
20+ print("eig inv solution")
21+ } else {
22+ # todo: iterative solution
23+ #formular: (Denman–Beavers iteration)
24+ print("iterative solution")
25+ Y = X
26+ #identity matrix
27+ Z = diag(matrix(1.0, rows=N, cols=1))
28+
29+ for (x in 1:10) {
30+ Y_new = (1 / 2) * (Y + inv(Z))
31+ Z_new = (1 / 2) * (Z + inv(Y))
32+ Y = Y_new
33+ Z = Z_new
34+ }
35+ sqrt_x = Y
36+ }
37+ }
38+ } else {
39+ sqrt_x = matrix (0, rows=N, cols=D);
40+ }
41+ }
42+
43+ # assumes square and diagonal matrix
44+ sqrtDiagMatrix = function(
45+ Matrix[Double] X
46+ )return(
47+ Matrix[Double] sqrt_x
48+ ){
49+ N = nrow(X);
50+
51+ sqrt_x = matrix (0, rows=N, cols=N);
52+ for (i in 1:N) {
53+ value = X[i, i];
54+ sqrt_x[i, i] = sqrt(value);
55+ }
56+ }
57+
58+ sqrtEigInv = function(
59+ Matrix[Double] X
60+ )return(
61+ Matrix[Double] sqrt_x
62+ ){
63+ [eValues, eVectors] = eigen(X);
64+ # calculate X = VDV^(-1) -> S = sqrt(D) -> sqrt_x = VSV^(-1)
65+ sqrtD = sqrtDiagMatrix(diag(eValues));
66+ V_Inv = inv(eVectors);
67+ sqrt_x = eVectors %*% sqrtD %*% V_Inv;
68+ }
69+
70+ isDiagonal = function (
71+ Matrix[Double] X
72+ )return(
73+ boolean diagonal
74+ ){
75+ N = nrow(X);
76+ D = ncol(X);
77+ noCells = N * D;
78+
79+ diag = diag(diag(X));
80+ compare = X == diag;
81+ sameCells = sum(compare);
82+
83+ #all cells should be the same to be diagonal
84+ diagonal = noCells == sameCells;
85+ }
86+
87+ # testing area
88+
89+ A = matrix(1, rows=2,cols=2)
90+ B = matrix(4, rows=2,cols=2)
91+
92+ # easy test
93+ B[1,1] = 4.0
94+ B[1,2] = 0.0
95+ B[2,1] = 0.0
96+ B[2,2] = 4.0
97+
98+ res = matrixSqrt(B, FALSE)
99+ print(as.scalar(res[1,1]))
100+ print(as.scalar(res[1,2]))
101+ print(as.scalar(res[2,1]))
102+ print(as.scalar(res[2,2]))
103+
104+ # with this test for diag sqrt
105+ B[1,1] = 16.0
106+ B[1,2] = 21.0
107+ B[2,1] = 28.0
108+ B[2,2] = 37.0
109+
110+ #matrixSqrt(A)
111+ #res = isDiagonal(A)
112+ #print(isDiagonal(A))
113+ res = matrixSqrt(B, FALSE)
114+ print(as.scalar(res[1,1]))
115+ print(as.scalar(res[1,2]))
116+ print(as.scalar(res[2,1]))
117+ print(as.scalar(res[2,2]))
118+
119+ # iter test
120+ B[1,1] = 16.0
121+ B[1,2] = 21.0
122+ B[2,1] = 28.0
123+ B[2,2] = 37.0
124+
125+ res = matrixSqrt(B, TRUE)
126+ print(as.scalar(res[1,1]))
127+ print(as.scalar(res[1,2]))
128+ print(as.scalar(res[2,1]))
129+ print(as.scalar(res[2,2]))
130+
131+ #d = matrix("1.0 0.0
132+ # 0.0 1.0", 2, 2);
133+ #n = 3
134+ #I = diag(matrix(1.0, rows=n, cols=1))
135+ #print(toString(I, decimal=1))
136+
137+ #test sqrt
138+ #C = sqrt(2)
139+ #print(C)
0 commit comments