Skip to content

Commit 466bd7e

Browse files
committed
tweak @⌛, sparse matrix adapt
1 parent de0b230 commit 466bd7e

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

src/gpu.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ end
3636
### misc
3737
# the generic versions of these trigger scalar indexing of CUDA, so provide
3838
# specialized versions:
39-
4039
# Pseudo-inverse of a diagonal operator: each entry is replaced by its
# reciprocal, except where that reciprocal is not finite, which becomes zero(T).
function pinv(D::Diagonal{T,<:CuBaseField}) where {T}
    d = D.diag
    return Diagonal(@. ifelse(isfinite(inv(d)), inv(d), $zero(T)))
end
4140
# Inverse of a diagonal operator. The elementwise `== 0` mask is flattened and
# materialized as an `Array` before reducing with `any`; a zero entry anywhere
# on the diagonal raises `SingularException(-1)`.
function inv(D::Diagonal{T,<:CuBaseField}) where {T}
    has_zero = any(Array((D.diag .== 0)[:]))
    has_zero && throw(SingularException(-1))
    return Diagonal(inv.(D.diag))
end
4241
# Fill the field's underlying array with `x` in place and return the field itself.
function fill!(f::CuBaseField, x)
    fill!(f.arr, x)
    return f
end
@@ -51,12 +50,15 @@ CUDA.sqrt(x::Complex) = CUDA.sqrt(CUDA.abs(x)) * CUDA.exp(im*CUDA.angle(x)/2)
5150
# Literal squaring (`x^2`) of a complex number, written as a plain product.
function CUDA.culiteral_pow(::typeof(^), x::Complex, ::Val{2})
    return x * x
end
5251
# Complex power: forward to the generic `^` operator.
function CUDA.pow(x::Complex, p)
    return x^p
end
5352

54-
# until https://github.com/JuliaGPU/CUDA.jl/pull/618
53+
# until https://github.com/JuliaGPU/CUDA.jl/pull/618 (CUDA 2.5)
5554
# Map `angle` to `CUDA.angle`.
function CUDA.cufunc(::typeof(angle))
    return CUDA.angle
end
5655

57-
# this makes cu(::SparseMatrixCSC) return a CuSparseMatrixCSR rather than a
58-
# dense CuArray
59-
adapt_structure(::Type{<:CuArray}, L::SparseMatrixCSC) = CuSparseMatrixCSR(L)
56+
# adapt SparseMatrixCSC ↔ CuSparseMatrixCSR (otherwise dense arrays are created)
57+
# Adapting a SparseMatrixCSC to a CuArray type: build a CuSparseMatrixCSR from it.
function adapt_structure(::Type{<:CuArray}, L::SparseMatrixCSC)
    return CuSparseMatrixCSR(L)
end
58+
# Adapting a CuSparseMatrixCSR to an Array type: build a SparseMatrixCSC from it.
function adapt_structure(::Type{<:Array}, L::CuSparseMatrixCSR)
    return SparseMatrixCSC(L)
end
59+
# Adapting a CuSparseMatrixCSR to a CuArray type: return it unchanged.
function adapt_structure(::Type{<:CuArray}, L::CuSparseMatrixCSR)
    return L
end
60+
# Adapting a SparseMatrixCSC to an Array type: return it unchanged.
function adapt_structure(::Type{<:Array}, L::SparseMatrixCSC)
    return L
end
61+
6062

6163
# CUDA somehow missing this one
6264
# see https://github.com/JuliaGPU/CuArrays.jl/issues/103

src/util.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -369,12 +369,13 @@ _isdef(ex) = @capture(ex, function f_(arg__) body_ end)
369369

370370
"""
371371
372-
@⌛ code ...
373-
@⌛ function_definition() = ....
372+
@⌛ [label] code ...
373+
@⌛ [label] function_definition() = ....
374374
375-
Label a section of code to be timed. The first form uses the code
376-
itself as a label, the second uses the function name, and it's the
377-
body of the function which is timed.
375+
Label a section of code to be timed. If a label string is not
376+
provided, the first form uses the code itself as a label, the second
377+
uses the function name, and it's the body of the function which is
378+
timed.
378379
379380
To run the timer and print output, returning the result of the
380381
calculation, use
@@ -383,16 +384,27 @@ calculation, use
383384
384385
Timing uses `TimerOutputs.get_defaulttimer()`.
385386
"""
386-
macro ⌛(ex)
387+
macro ⌛(args...)
388+
if length(args)==1
389+
label, ex = nothing, args[1]
390+
else
391+
label, ex = esc(args[1]), args[2]
392+
end
387393
source_str = last(splitpath(string(__source__.file)))*":"*string(__source__.line)
388394
if _isdef(ex)
389395
sdef = splitdef(ex)
396+
if isnothing(label)
397+
label = "$(string(sdef[:name]))(…) ($source_str)"
398+
end
390399
sdef[:body] = quote
391-
CMBLensing.@timeit $("$(string(sdef[:name]))(…) ($source_str)") $(sdef[:body])
400+
CMBLensing.@timeit $label $(sdef[:body])
392401
end
393402
esc(combinedef(sdef))
394403
else
395-
:(@timeit $("$(Base._truncate_at_width_or_chars(string(prewalk(rmlines,ex)),26)) ($source_str)") $(esc(ex)))
404+
if isnothing(label)
405+
label = "$(Base._truncate_at_width_or_chars(string(prewalk(rmlines,ex)),26)) ($source_str)"
406+
end
407+
:(@timeit $label $(esc(ex)))
396408
end
397409
end
398410

0 commit comments

Comments
 (0)