Skip to content

Commit 5e444a0

Browse files
authored
Merge pull request #200 from JuliaGPU/tb/warn_scalar
Warn once about scalar operations.
2 parents 3b7979a + 92b9879 commit 5e444a0

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

src/indexing.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,40 @@
1-
# mechanism to disallow indexing
1+
# mechanism to disallow scalar operations
22

3-
const _allowscalar = Ref(true)
3+
const scalar_allowed = Ref(true)
4+
const scalar_warned = Ref(false)
45

5-
allowscalar(flag = true) = (_allowscalar[] = flag)
6+
function allowscalar(flag = true)
7+
scalar_allowed[] = flag
8+
scalar_warned[] = false
9+
return
10+
end
611

712
function assertscalar(op = "operation")
8-
_allowscalar[] || error("$op is disallowed")
9-
return
13+
if !scalar_allowed[]
14+
error("$op is disallowed")
15+
elseif !scalar_warned[]
16+
@warn "Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`"
17+
scalar_warned[] = true
18+
end
19+
return
1020
end
1121

1222
macro allowscalar(ex)
1323
quote
14-
local prev = _allowscalar[]
15-
_allowscalar[] = true
24+
local prev = scalar_allowed[]
25+
scalar_allowed[] = true
1626
local ret = $(esc(ex))
17-
_allowscalar[] = prev
27+
scalar_allowed[] = prev
1828
ret
1929
end
2030
end
2131

2232
macro disallowscalar(ex)
2333
quote
24-
local prev = _allowscalar[]
25-
_allowscalar[] = false
34+
local prev = scalar_allowed[]
35+
scalar_allowed[] = false
2636
local ret = $(esc(ex))
27-
_allowscalar[] = prev
37+
scalar_allowed[] = prev
2838
ret
2939
end
3040
end

0 commit comments

Comments
 (0)