Skip to content

Commit cdb599b

Browse files
committed
Merge branch 'feat/sqrt-in-dml' of github.com:get4flo/systemds into feat/sqrt-in-dml
2 parents dfd4e15 + adbf08e commit cdb599b

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed

hello.dml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
A = matrix(1, rows=2,cols=2)
2+
B = matrix(3, rows=2,cols=2)
3+
C = 10
4+
D = A %*% B + C * 2.1
5+
print( "D[1,1]:" + as.scalar(D[1,1]))
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
matrixSqrt = function(
2+
Matrix[Double] X,
3+
boolean iterMethod
4+
)return(
5+
Matrix[Double] sqrt_x
6+
){
7+
N = nrow(X);
8+
D = ncol(X);
9+
10+
if( D == N ){
11+
# Any non singualar square matrix has a square root
12+
isDiag = isDiagonal(X)
13+
if(isDiag) {
14+
print("diag solution")
15+
sqrt_x = sqrtDiagMatrix(X);
16+
} else {
17+
if(!iterMethod) {
18+
#todo: check if all EigVal positive than possible
19+
sqrt_x = sqrtEigInv(X);
20+
print("eig inv solution")
21+
} else {
22+
# todo: iterative solution
23+
#formular: (Denman–Beavers iteration)
24+
print("iterative solution")
25+
Y = X
26+
#identity matrix
27+
Z = diag(matrix(1.0, rows=N, cols=1))
28+
29+
for (x in 1:10) {
30+
Y_new = (1 / 2) * (Y + inv(Z))
31+
Z_new = (1 / 2) * (Z + inv(Y))
32+
Y = Y_new
33+
Z = Z_new
34+
}
35+
sqrt_x = Y
36+
}
37+
}
38+
} else {
39+
sqrt_x = matrix (0, rows=N, cols=D);
40+
}
41+
}
42+
43+
# assumes square and diagonal matrix
44+
sqrtDiagMatrix = function(
45+
Matrix[Double] X
46+
)return(
47+
Matrix[Double] sqrt_x
48+
){
49+
N = nrow(X);
50+
51+
sqrt_x = matrix (0, rows=N, cols=N);
52+
for (i in 1:N) {
53+
value = X[i, i];
54+
sqrt_x[i, i] = sqrt(value);
55+
}
56+
}
57+
58+
sqrtEigInv = function(
59+
Matrix[Double] X
60+
)return(
61+
Matrix[Double] sqrt_x
62+
){
63+
[eValues, eVectors] = eigen(X);
64+
# calculate X = VDV^(-1) -> S = sqrt(D) -> sqrt_x = VSV^(-1)
65+
sqrtD = sqrtDiagMatrix(diag(eValues));
66+
V_Inv = inv(eVectors);
67+
sqrt_x = eVectors %*% sqrtD %*% V_Inv;
68+
}
69+
70+
isDiagonal = function (
71+
Matrix[Double] X
72+
)return(
73+
boolean diagonal
74+
){
75+
N = nrow(X);
76+
D = ncol(X);
77+
noCells = N * D;
78+
79+
diag = diag(diag(X));
80+
compare = X == diag;
81+
sameCells = sum(compare);
82+
83+
#all cells should be the same to be diagonal
84+
diagonal = noCells == sameCells;
85+
}
86+
87+
# testing area
88+
89+
A = matrix(1, rows=2,cols=2)
90+
B = matrix(4, rows=2,cols=2)
91+
92+
# easy test
93+
B[1,1] = 4.0
94+
B[1,2] = 0.0
95+
B[2,1] = 0.0
96+
B[2,2] = 4.0
97+
98+
res = matrixSqrt(B, FALSE)
99+
print(as.scalar(res[1,1]))
100+
print(as.scalar(res[1,2]))
101+
print(as.scalar(res[2,1]))
102+
print(as.scalar(res[2,2]))
103+
104+
# with this test for diag sqrt
105+
B[1,1] = 16.0
106+
B[1,2] = 21.0
107+
B[2,1] = 28.0
108+
B[2,2] = 37.0
109+
110+
#matrixSqrt(A)
111+
#res = isDiagonal(A)
112+
#print(isDiagonal(A))
113+
res = matrixSqrt(B, FALSE)
114+
print(as.scalar(res[1,1]))
115+
print(as.scalar(res[1,2]))
116+
print(as.scalar(res[2,1]))
117+
print(as.scalar(res[2,2]))
118+
119+
# iter test
120+
B[1,1] = 16.0
121+
B[1,2] = 21.0
122+
B[2,1] = 28.0
123+
B[2,2] = 37.0
124+
125+
res = matrixSqrt(B, TRUE)
126+
print(as.scalar(res[1,1]))
127+
print(as.scalar(res[1,2]))
128+
print(as.scalar(res[2,1]))
129+
print(as.scalar(res[2,2]))
130+
131+
#d = matrix("1.0 0.0
132+
# 0.0 1.0", 2, 2);
133+
#n = 3
134+
#I = diag(matrix(1.0, rows=n, cols=1))
135+
#print(toString(I, decimal=1))
136+
137+
#test sqrt
138+
#C = sqrt(2)
139+
#print(C)

0 commit comments

Comments
 (0)