Skip to content

Commit 3d43525

Browse files
authored
Compile parametric methods that contain llvmcall (fixes #112, fixes #288) (#289)
1 parent 04d1fb5 commit 3d43525

File tree

4 files changed

+50
-1
lines changed

4 files changed

+50
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1717
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1818
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
1919
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
20+
Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"
2021
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2122

2223
[targets]
23-
test = ["Test", "Distributed", "Dates", "SHA", "Mmap"]
24+
test = ["Test", "Dates", "Distributed", "Mmap", "SHA", "Tensors"]

src/construct.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ function prepare_framecode(method::Method, @nospecialize(argtypes); enter_genera
165165
generator = false
166166
end
167167
end
168+
# Currenly, our strategy to deal with llvmcall can't handle parametric functions
169+
# (the "mini interpreter" runs in module scope, not method scope)
170+
if !isempty(lenv) && (hasarg(isequal(:llvmcall), code.code) || hasarg(a->is_global_ref(a, Base, :llvmcall), code.code))
171+
return Compiled()
172+
end
168173
framecode = FrameCode(method, code; generator=generator)
169174
if is_generated(method) && !enter_generated
170175
genframedict[(method, argtypes)] = framecode

src/utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,19 @@ function scan_ssa_use!(used::BitSet, @nospecialize(stmt))
8787
end
8888
end
8989

90+
function hasarg(predicate, args)
91+
predicate(args) && return true
92+
for a in args
93+
predicate(a) && return true
94+
if isa(a, Expr)
95+
hasarg(predicate, a.args) && return true
96+
elseif isa(a, QuoteNode)
97+
predicate(a.value) && return true
98+
end
99+
end
100+
return false
101+
end
102+
90103
## Predicates
91104

92105
is_goto_node(@nospecialize(node)) = isa(node, GotoNode) || isexpr(node, :gotoifnot)

test/interpret.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,36 @@ function f_mmap()
513513
end
514514
@interpret f_mmap()
515515

516+
# parametric llvmcall (issues #112 and #288)
517+
module VecTest
518+
using Tensors
519+
Vec{N,T} = NTuple{N,VecElement{T}}
520+
# The following test mimic SIMD.jl
521+
const _llvmtypes = Dict{DataType, String}(
522+
Float64 => "double",
523+
Float32 => "float",
524+
Int32 => "i32",
525+
Int64 => "i64"
526+
)
527+
@generated function vecadd(x::Vec{N, T}, y::Vec{N, T}) where {N, T}
528+
llvmT = _llvmtypes[T]
529+
func = T <: AbstractFloat ? "fadd" : "add"
530+
exp = """
531+
%3 = $(func) <$(N) x $(llvmT)> %0, %1
532+
ret <$(N) x $(llvmT)> %3
533+
"""
534+
return quote
535+
Base.@_inline_meta
536+
Core.getfield(Base, :llvmcall)($exp, Vec{$N, $T}, Tuple{Vec{$N, $T}, Vec{$N, $T}}, x, y)
537+
end
538+
end
539+
f() = 1.0 * one(Tensor{2,3})
540+
end
541+
let a = (VecElement{Float64}(1.0), VecElement{Float64}(2.0))
542+
@test @interpret(VecTest.vecadd(a, a)) == VecTest.vecadd(a, a)
543+
end
544+
@test @interpret(VecTest.f()) == [1 0 0; 0 1 0; 0 0 1]
545+
516546
# Test exception type for undefined variables
517547
f() = s = s + 1
518548
@test_throws UndefVarError @interpret f()

0 commit comments

Comments
 (0)