Skip to content

Commit 4a82e4a

Browse files
committed
adaptive mixed precision
1 parent 3c168b2 commit 4a82e4a

File tree

3 files changed

+103
-1
lines changed

3 files changed

+103
-1
lines changed

src/Dagger.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ include("array/sort.jl")
7474
include("array/linalg.jl")
7575
include("array/mul.jl")
7676
include("array/cholesky.jl")
77-
77+
include("array/adaptive_mp.jl")
7878
# Visualization
7979
include("visualization.jl")
8080
include("ui/gantt-common.jl")

src/array/adaptive_mp.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
function tile_precision(uplo, global_norm, scalar_factore, tolerance, A)
2+
tile_sqr = 0.0
3+
4+
if uplo == 'G'
5+
tile_sqr = mapreduce(LinearAlgebra.norm_sqr, +, A)
6+
elseif uplo == 'L'
7+
tile_sqr= mapreduce(LinearAlgebra.norm_sqr, +, LowerTriangular(A))
8+
elseif uplo == 'U'
9+
tile_sqr= mapreduce(LinearAlgebra.norm_sqr, +, UpperTriangular(A))
10+
end
11+
tile_norm = sqrt(tile_sqr)
12+
13+
cal = tile_norm * scalar_factore / global_norm
14+
decision_hp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float16);
15+
decision_sp = tile_norm * scalar_factore / global_norm < tolerance / eps(Float32);
16+
decision_fp8 = tile_norm * scalar_factore / global_norm < tolerance / 0.0625;
17+
18+
if decision_fp8
19+
return "FP8"
20+
elseif decision_hp
21+
return "FP16"
22+
elseif decision_sp
23+
return "FP32"
24+
else
25+
return "FP64"
26+
end
27+
end
28+
29+
function adaptive_mp!(A::UpperTriangular{T,<:DArray{T,2}}, MP::UpperTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T
30+
31+
Ac = parent(A).chunks
32+
MPc= parent(MP).chunks
33+
mt, nt = size(Ac)
34+
35+
global_norm = LinearAlgebra.norm2(A)
36+
37+
for m in range(1, mt)
38+
for n in range(m, nt)
39+
if m==n
40+
MP[m, n] = Dagger.@spawn tile_precision('U', global_norm, max(mt, nt), tolerance, Ac[m, n])
41+
else
42+
MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n])
43+
end
44+
45+
end
46+
end
47+
return UpperTriangular(MP)
48+
end
49+
50+
function adaptive_mp!(A::LowerTriangular{T,<:DArray{T,2}}, MP::LowerTriangular{String,<:DArray{String,2}}, tolerance::Float64) where T
51+
52+
Ac = parent(A).chunks
53+
MPc= parent(MP).chunks
54+
mt, nt = size(Ac)
55+
56+
global_norm = LinearAlgebra.norm2(A)
57+
58+
for m in range(1, mt)
59+
for n in range(1, m)
60+
if m==n
61+
MP[m, n] = Dagger.@spawn tile_precision('L', global_norm, max(mt, nt), tolerance, Ac[m, n])
62+
else
63+
MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n])
64+
end
65+
66+
end
67+
end
68+
return LowerTriangular(MP)
69+
end
70+
71+
72+
function adaptive_mp!(A::DArray{T,2}, MP::DArray{String,2}, tolerance::Float64) where T
73+
74+
Ac = parent(A).chunks
75+
MPc= parent(MP).chunks
76+
mt, nt = size(Ac)
77+
78+
global_norm = LinearAlgebra.norm2(A)
79+
80+
for m in range(1, mt)
81+
for n in range(1, nt)
82+
MP[m, n] = Dagger.@spawn tile_precision('G', global_norm, max(mt, nt), tolerance, Ac[m, n])
83+
end
84+
end
85+
86+
return MP
87+
end
88+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using Dagger
2+
using LinearAlgebra
3+
using KernelFunctions
4+
using Distances
5+
6+
k = GammaExponentialKernel(; γ=0.5, metric=Euclidean());
7+
x = randn(4000, 2000);
8+
A = kernelmatrix(k, x);
9+
DA = view(A, Blocks(400, 400));
10+
MP = fill("FP64", 5, 5);
11+
DMP = view(MP, Blocks(1, 1));
12+
13+
Dagger.adaptive_mp!(DA, DMP, 10^-4);
14+
collect(DMP)

0 commit comments

Comments
 (0)