Skip to content

Commit 7785a02

Browse files
committed
Make allowscalar three-state and use for JLArray tests.
1 parent a29df67 commit 7785a02

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

src/host/indexing.jl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,72 @@ export allowscalar, @allowscalar, assertscalar
55

66
# mechanism to disallow scalar operations
77

8-
const scalar_allowed = Ref(true)
8+
@enum ScalarIndexing ScalarAllowed ScalarWarned ScalarDisallowed
9+
10+
const scalar_allowed = Ref(ScalarWarned)
911
const scalar_warned = Ref(false)
1012

11-
function allowscalar(flag = true)
12-
scalar_allowed[] = flag
13+
"""
14+
allowscalar(allow=true, warn=true)
15+
16+
Configure whether scalar indexing is allowed depending on the value of `allow`.
17+
18+
If allowed, `warn` can be set to throw a single warning instead. Calling this function will
19+
reset the state of the warning, and throw a new warning on subsequent scalar iteration.
20+
"""
21+
function allowscalar(allow::Bool=true, warn::Bool=true)
1322
scalar_warned[] = false
23+
scalar_allowed[] = if allow && !warn
24+
ScalarAllowed
25+
elseif allow
26+
ScalarWarned
27+
else
28+
ScalarDisallowed
29+
end
1430
return
1531
end
1632

33+
"""
34+
assertscalar(op::String)
35+
36+
Assert that a certain operation `op` performs scalar indexing. If this is not allowed, an
37+
error will be thrown ([`allowscalar`](@ref)).
38+
"""
1739
function assertscalar(op = "operation")
18-
if !scalar_allowed[]
40+
if scalar_allowed[] == ScalarDisallowed
1941
error("$op is disallowed")
20-
elseif !scalar_warned[]
42+
elseif scalar_allowed[] == ScalarWarned && !scalar_warned[]
2143
@warn "Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`"
2244
scalar_warned[] = true
2345
end
2446
return
2547
end
2648

49+
"""
50+
@allowscalar ex...
51+
@disallowscalar ex...
52+
53+
Temporarily allow or disallow scalar iteration.
54+
55+
Note that this functionality is intended for functionality that is known and allowed to use
56+
scalar iteration (or not), i.e., there is no option to throw a warning. Only use this on
57+
fine-grained expressions.
58+
"""
2759
macro allowscalar(ex)
2860
quote
2961
local prev = scalar_allowed[]
30-
scalar_allowed[] = true
62+
scalar_allowed[] = ScalarAllowed
3163
local ret = $(esc(ex))
3264
scalar_allowed[] = prev
3365
ret
3466
end
3567
end
3668

69+
@doc (@doc @allowscalar) ->
3770
macro disallowscalar(ex)
3871
quote
3972
local prev = scalar_allowed[]
40-
scalar_allowed[] = false
73+
scalar_allowed[] = ScalarDisallowed
4174
local ret = $(esc(ex))
4275
scalar_allowed[] = prev
4376
ret

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ include("testsuite.jl")
44

55
@testset "JLArray" begin
66
using GPUArrays.JLArrays
7+
JLArrays.allowscalar(false)
78
TestSuite.test(JLArray)
89
end

0 commit comments

Comments
 (0)