Skip to content

Commit aa2164b

Browse files
authored
Fix #107 (#108)
* fix #107 * changelog * fix test * minor
1 parent de35ba7 commit aa2164b

File tree

3 files changed

+100
-14
lines changed

3 files changed

+100
-14
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
OhMyThreads.jl Changelog
22
=========================
33

4+
Version 0.5.3
5+
-------------
6+
- ![Enhancement][badge-enhancement] For the special/fake "macros" like, e.g., `@set`, support the verbose form `OhMyThreads.@set` within a `@tasks` for-loop (#107).
7+
48
Version 0.5.2
59
-------------
610
- ![Enhancement][badge-enhancement] For empty input (e.g. `Float64[]` or `11:10`) behavior is now aligned with the serial functions in `Base`.

src/macro_impl.jl

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,28 @@ using OhMyThreads.Tools: OnlyOneRegion, try_enter!
22
using OhMyThreads.Tools: SimpleBarrier
33
using OhMyThreads: OhMyThreads
44

5+
function _is_special_macro_expr(arg;
6+
lookfor = ("@set", "@local", "@only_one", "@one_by_one", "@barrier"))
7+
if !(arg isa Expr)
8+
return false
9+
end
10+
lookfor_symbols = Symbol.(lookfor)
11+
if arg.head == :macrocall
12+
if arg.args[1] isa Symbol && arg.args[1] in lookfor_symbols
13+
# support, e.g., @set
14+
return true
15+
elseif arg.args[1] isa Expr && arg.args[1].head == Symbol(".")
16+
# support, e.g., OhMyThreads.@set
17+
x = arg.args[1]
18+
if x.args[1] == Symbol("OhMyThreads") && x.args[2] isa QuoteNode &&
19+
x.args[2].value in lookfor_symbols
20+
return true
21+
end
22+
end
23+
end
24+
return false
25+
end
26+
527
function tasks_macro(forex; __module__)
628
if forex.head != :for
729
throw(ErrorException("Expected a for loop after `@tasks`."))
@@ -24,15 +46,7 @@ function tasks_macro(forex; __module__)
2446
# Escape everything in the loop body that is not used in conjuction with one of our
2547
# "macros", e.g. @set or @local. Code inside of these macro blocks will be escaped by
2648
# the respective "macro" handling functions below.
27-
for i in findall(forbody.args) do arg
28-
!(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set")) &&
29-
!(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@local")) &&
30-
!(arg isa Expr && arg.head == :macrocall &&
31-
arg.args[1] == Symbol("@only_one")) &&
32-
!(arg isa Expr && arg.head == :macrocall &&
33-
arg.args[1] == Symbol("@one_by_one")) &&
34-
!(arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@barrier"))
35-
end
49+
for i in findall(!_is_special_macro_expr, forbody.args)
3650
forbody.args[i] = esc(forbody.args[i])
3751
end
3852

@@ -138,7 +152,7 @@ function _maybe_handle_atlocal_block!(args)
138152
locals_before = nothing
139153
local_inner = nothing
140154
tlsidx = findfirst(args) do arg
141-
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@local")
155+
_is_special_macro_expr(arg; lookfor = (Symbol("@local"),))
142156
end
143157
if !isnothing(tlsidx)
144158
locals_before, local_inner = _unfold_atlocal_block(args[tlsidx].args[3])
@@ -198,7 +212,7 @@ end
198212

199213
function _maybe_handle_atset_block!(settings, args)
200214
idcs = findall(args) do arg
201-
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set")
215+
_is_special_macro_expr(arg; lookfor = (Symbol("@set"),))
202216
end
203217
isnothing(idcs) && return # no @set block found
204218
for i in idcs
@@ -240,7 +254,7 @@ end
240254

241255
function _maybe_handle_atonlyone_blocks!(args)
242256
idcs = findall(args) do arg
243-
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@only_one")
257+
_is_special_macro_expr(arg; lookfor = (Symbol("@only_one"),))
244258
end
245259
isnothing(idcs) && return # no @only_one blocks
246260
setup_onlyone_blocks = quote end
@@ -260,7 +274,7 @@ end
260274

261275
function _maybe_handle_atonebyone_blocks!(args)
262276
idcs = findall(args) do arg
263-
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@one_by_one")
277+
_is_special_macro_expr(arg; lookfor = (Symbol("@one_by_one"),))
264278
end
265279
isnothing(idcs) && return # no @one_by_one blocks
266280
setup_onebyone_blocks = quote end
@@ -280,7 +294,7 @@ end
280294

281295
function _maybe_handle_atbarriers!(args, settings)
282296
idcs = findall(args) do arg
283-
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@barrier")
297+
_is_special_macro_expr(arg; lookfor = (Symbol("@barrier"),))
284298
end
285299
isnothing(idcs) && return # no @barrier found
286300
setup_barriers = quote end

test/runtests.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,4 +597,72 @@ end;
597597
end
598598
end
599599

600+
@testset "verbose special macro usage" begin
601+
# OhMyThreads.@set
602+
@test @tasks(for i in 1:3
603+
OhMyThreads.@set reducer = (+)
604+
i
605+
end) == 6
606+
@test @tasks(for i in 1:3
607+
OhMyThreads.@set begin
608+
reducer = (+)
609+
end
610+
i
611+
end) == 6
612+
# OhMyThreads.@local
613+
ntd = 2 * Threads.nthreads()
614+
@test @tasks(for i in 1:ntd
615+
OhMyThreads.@local x::Ref{Int64} = Ref(0)
616+
OhMyThreads.@set begin
617+
reducer = (+)
618+
scheduler = :static
619+
end
620+
x[] += 1
621+
x[]
622+
end) == @tasks(for i in 1:ntd
623+
@local x::Ref{Int64} = Ref(0)
624+
@set begin
625+
reducer = (+)
626+
scheduler = :static
627+
end
628+
x[] += 1
629+
x[]
630+
end)
631+
# OhMyThreads.@only_one
632+
x = 0
633+
y = 0
634+
try
635+
@tasks for i in 1:10
636+
OhMyThreads.@set ntasks = 10
637+
638+
y += 1 # not safe (race condition)
639+
OhMyThreads.@only_one begin
640+
x += 1 # parallel-safe because only a single task will execute this
641+
end
642+
end
643+
@test x == 1 # only a single task should have incremented x
644+
catch ErrorException
645+
@test false
646+
end
647+
# OhMyThreads.@one_by_one
648+
test_f = () -> begin
649+
sao = SingleAccessOnly()
650+
x = 0
651+
y = 0
652+
@tasks for i in 1:10
653+
OhMyThreads.@set ntasks = 10
654+
655+
y += 1 # not safe (race condition)
656+
OhMyThreads.@one_by_one begin
657+
x += 1 # parallel-safe because inside of one_by_one region
658+
acquire(sao) do
659+
sleep(0.01)
660+
end
661+
end
662+
end
663+
return x
664+
end
665+
@test test_f() == 10
666+
end
667+
600668
# Todo way more testing, and easier tests to deal with

0 commit comments

Comments
 (0)