Skip to content

Commit aca1263

Browse files
committed
fix planar handling of backend and allocator
1 parent f1f842f commit aca1263

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

src/planar/macros.jl

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,27 @@ function planarparser(planarexpr, kwargs...)
2424
# braiding tensors need to be instantiated before kwargs are processed
2525
push!(parser.preprocessors, _construct_braidingtensors)
2626

27+
# the order of backend and allocator postprocessors are important so let's find them first
28+
hasbackend = false
29+
for (name, val) in kwargs
30+
if name == :backend
31+
hasbackend = true
32+
backend = val
33+
push!(parser.postprocessors, ex -> TO.insertbackend(ex, backend))
34+
break
35+
end
36+
end
37+
for (name, val) in kwargs
38+
if name == :allocator
39+
allocator = val
40+
if !hasbackend
41+
backend = Expr(:call, GlobalRef(TensorOperations, :DefaultBackend))
42+
push!(parser.postprocessors, ex -> TO.insertbackend(ex, backend))
43+
end
44+
push!(parser.postprocessors, ex -> TO.insertallocator(ex, allocator))
45+
break
46+
end
47+
end
2748
for (name, val) in kwargs
2849
if name == :order
2950
isexpr(val, :tuple) ||
@@ -50,16 +71,8 @@ function planarparser(planarexpr, kwargs...)
5071
throw(ArgumentError("Invalid use of `opt`, should be `opt=true` or `opt=OptExpr`"))
5172
end
5273
parser.contractiontreebuilder = network -> TO.optimaltree(network, optdict)[1]
53-
elseif name == :backend
54-
val isa Symbol ||
55-
throw(ArgumentError("Backend should be a symbol."))
56-
push!(parser.postprocessors, ex -> insert_operationbackend(ex, val))
57-
elseif name == :allocator
58-
val isa Symbol ||
59-
throw(ArgumentError("Allocator should be a symbol."))
60-
push!(parser.postprocessors, ex -> TO.insert_allocatorbackend(ex, val))
61-
else
62-
throw(ArgumentError("Unknown keyword argument `name`."))
74+
elseif !(name == :allocator || name == :backend) # already processed
75+
throw(ArgumentError("Unknown keyword argument `$name`."))
6376
end
6477
end
6578

0 commit comments

Comments
 (0)