Skip to content

Commit 03a86d5

Browse files
authored
Merge pull request #231 from janEbert/clamp
Add Base.clamp!
2 parents c9b5f99 + 13aa4ac commit 03a86d5

File tree

4 files changed

+34
-0
lines changed

4 files changed

+34
-0
lines changed

src/GPUArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ include("host/indexing.jl")
3030
include("host/broadcast.jl")
3131
include("host/mapreduce.jl")
3232
include("host/linalg.jl")
33+
include("host/math.jl")
3334
include("host/random.jl")
3435
include("host/quirks.jl")
3536

src/host/math.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import Base.clamp!
2+
3+
function Base.clamp!(A::AbstractGPUArray, low, high)
4+
function kernel(state, A, low, high)
5+
I = @cartesianidx A state
6+
A[I...] = clamp(A[I...], low, high)
7+
return
8+
end
9+
gpu_call(kernel, A, low, high)
10+
return A
11+
end

test/testsuite.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ include("testsuite/vector.jl")
3737
include("testsuite/mapreduce.jl")
3838
include("testsuite/broadcasting.jl")
3939
include("testsuite/linalg.jl")
40+
include("testsuite/math.jl")
4041
include("testsuite/fft.jl")
4142
include("testsuite/random.jl")
4243

@@ -53,6 +54,7 @@ function test(AT::Type{<:AbstractGPUArray})
5354
TestSuite.test_mapreduce(AT)
5455
TestSuite.test_broadcasting(AT)
5556
TestSuite.test_linalg(AT)
57+
TestSuite.test_math(AT)
5658
TestSuite.test_fft(AT)
5759
TestSuite.test_random(AT)
5860
end

test/testsuite/math.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
function test_math(AT)
2+
@testset "math functionality" begin
3+
for ET in supported_eltypes()
4+
# Skip complex numbers
5+
ET in (Complex, ComplexF32, ComplexF64) && continue
6+
7+
T = AT{ET}
8+
@testset "$ET" begin
9+
range = ET <: Integer ? (ET(-2):ET(2)) : ET
10+
low = ET(-1)
11+
high = ET(1)
12+
@testset "clamp!" begin
13+
for N in (2, 10)
14+
@test compare(x -> clamp!(x, low, high), AT, rand(range, N, N))
15+
end
16+
end
17+
end
18+
end
19+
end
20+
end

0 commit comments

Comments
 (0)