@@ -5,39 +5,72 @@ export allowscalar, @allowscalar, assertscalar
5
5
6
6
# mechanism to disallow scalar operations
7
7
8
- const scalar_allowed = Ref (true )
8
+ @enum ScalarIndexing ScalarAllowed ScalarWarned ScalarDisallowed
9
+
10
+ const scalar_allowed = Ref (ScalarWarned)
9
11
const scalar_warned = Ref (false )
10
12
11
- function allowscalar (flag = true )
12
- scalar_allowed[] = flag
13
+ """
14
+ allowscalar(allow=true, warn=true)
15
+
16
+ Configure whether scalar indexing is allowed depending on the value of `allow`.
17
+
18
+ If allowed, `warn` can be set to throw a single warning instead. Calling this function will
19
+ reset the state of the warning, and throw a new warning on subsequent scalar iteration.
20
+ """
21
+ function allowscalar (allow:: Bool = true , warn:: Bool = true )
13
22
scalar_warned[] = false
23
+ scalar_allowed[] = if allow && ! warn
24
+ ScalarAllowed
25
+ elseif allow
26
+ ScalarWarned
27
+ else
28
+ ScalarDisallowed
29
+ end
14
30
return
15
31
end
16
32
33
+ """
34
+ assertscalar(op::String)
35
+
36
+ Assert that a certain operation `op` performs scalar indexing. If this is not allowed, an
37
+ error will be thrown ([`allowscalar`](@ref)).
38
+ """
17
39
function assertscalar (op = " operation" )
18
- if ! scalar_allowed[]
40
+ if scalar_allowed[] == ScalarDisallowed
19
41
error (" $op is disallowed" )
20
- elseif ! scalar_warned[]
42
+ elseif scalar_allowed[] == ScalarWarned && ! scalar_warned[]
21
43
@warn " Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`"
22
44
scalar_warned[] = true
23
45
end
24
46
return
25
47
end
26
48
49
+ """
50
+ @allowscalar ex...
51
+ @disallowscalar ex...
52
+
53
+ Temporarily allow or disallow scalar iteration.
54
+
55
+ Note that this functionality is intended for functionality that is known and allowed to use
56
+ scalar iteration (or not), i.e., there is no option to throw a warning. Only use this on
57
+ fine-grained expressions.
58
+ """
27
59
macro allowscalar (ex)
28
60
quote
29
61
local prev = scalar_allowed[]
30
- scalar_allowed[] = true
62
+ scalar_allowed[] = ScalarAllowed
31
63
local ret = $ (esc (ex))
32
64
scalar_allowed[] = prev
33
65
ret
34
66
end
35
67
end
36
68
69
+ @doc (@doc @allowscalar ) ->
37
70
macro disallowscalar (ex)
38
71
quote
39
72
local prev = scalar_allowed[]
40
- scalar_allowed[] = false
73
+ scalar_allowed[] = ScalarDisallowed
41
74
local ret = $ (esc (ex))
42
75
scalar_allowed[] = prev
43
76
ret
0 commit comments