Skip to content

Commit 1cac3f2

Browse files
committed
get cu string compilation working
1 parent 5c85f8a commit 1cac3f2

File tree

3 files changed

+64
-53
lines changed

3 files changed

+64
-53
lines changed
Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,34 @@
1-
# EXCLUDE FROM TESTING
1+
#TODO tag CUDArt#103 and use that
22
import CUDArt
3+
export CompileError
34

45
# Generate a temporary file with specific suffix
56
# NOTE: mkstemps is glibc 2.19+, so emulate its behavior
67
function mkstemps(suffix::AbstractString)
    # `tempname()` supplies the stem; the caller's suffix is appended to it.
    stem = tempname()
    candidate = string(stem, suffix)

    # Insert an increasing counter (".1", ".2", ...) between stem and suffix
    # until the candidate path does not name an existing file.
    attempt = 0
    while isfile(candidate)
        attempt += 1
        candidate = string(stem, ".", attempt, suffix)
    end

    # Hand back the chosen path together with a stream opened for writing.
    handle = Base.open(candidate, "w")
    return (candidate, handle)
end
1720

21+
# Build a top-level expression that exports `kernel` and binds it to the
# result of `_compile(dev, "<kernel name>", code, <file>)`.
#
# - `dev`:    expression for the device, evaluated in the caller's scope
# - `kernel`: identifier to define and export; its string form is also passed
#             to `_compile` as the kernel name
# - `code`:   expression yielding the kernel source, evaluated in the caller's
#             scope
macro compile(dev, kernel, code)
    kernel_name = string(kernel)
    # NOTE(review): `@__FILE__` is expanded where this macro is defined, so
    # this is presumably the defining file rather than the call site --
    # confirm that is what `_compile` expects.
    containing_file = @__FILE__

    # `code` is escaped just like `dev` and `kernel`; unescaped, a `code`
    # expression that names a caller variable (or uses string interpolation)
    # would be resolved in this macro's module instead of the caller's.
    binding = :($(esc(kernel)) = _compile($(esc(dev)), $kernel_name, $(esc(code)), $containing_file))

    return Expr(:toplevel, Expr(:export, esc(kernel)), binding)
end
1829

19-
type CompileError <: Base.WrappedException
30+
immutable CompileError <: Exception
2031
message::String
21-
error
2232
end
2333

2434
const builddir = joinpath(@__DIR__, ".cache")
@@ -31,7 +41,7 @@ function _compile(dev, kernel, code, containing_file)
3141
mkpath(builddir)
3242
end
3343

34-
# Check if we need to compile
44+
# check if we need to compile
3545
codehash = hex(hash(code))
3646
output = "$builddir/$(kernel)_$(codehash)-$(arch).ptx"
3747
if isfile(output)
@@ -40,47 +50,44 @@ function _compile(dev, kernel, code, containing_file)
4050
need_compile = true
4151
end
4252

43-
# Compile the source, if necessary
53+
# compile the source, if necessary
4454
if need_compile
45-
# Write the source into a compilable file
55+
# write the source to a compilable file
4656
(source, io) = mkstemps(".cu")
4757
write(io, """
4858
extern "C"
4959
{
5060
$code
5161
}
5262
""")
53-
close(io)
63+
Base.close(io)
5464

5565
compile_flags = vcat(CUDArt.toolchain_flags, ["--gpu-architecture", arch])
56-
try
57-
# TODO: capture STDERR
58-
run(pipeline(`$(CUDArt.toolchain_nvcc) $(compile_flags) -ptx -o $output $source`, stderr=DevNull))
59-
catch ex
60-
isa(ex, ErrorException) || rethrow(ex)
61-
rethrow(CompileError("compilation of kernel $kernel failed (typo in C++ source?)", ex))
62-
finally
63-
rm(source)
66+
err = Pipe()
67+
cmd = `$(CUDArt.toolchain_nvcc) $(compile_flags) -ptx -o $output $source`
68+
result = success(pipeline(cmd; stdout=DevNull, stderr=err))
69+
Base.close(err.in)
70+
rm(source)
71+
72+
errors = readstring(err)
73+
if !result
74+
throw(CompileError("compilation of kernel $kernel failed\n$errors"))
75+
elseif !isempty(errors)
76+
warn("during compilation of kernel $kernel:\n$errors")
6477
end
6578

6679
if !isfile(output)
6780
error("compilation of kernel $kernel failed (no output generated)")
6881
end
6982
end
7083

71-
# Pass the module to the CUDA driver
72-
mod = try
73-
CUDAdrv.CuModuleFile(output)
74-
catch ex
75-
rethrow(CompileError("loading of kernel $kernel failed (invalid CUDA code?)", ex))
76-
end
84+
mod = CUDAdrv.CuModuleFile(output)
85+
return CUDAdrv.CuFunction(mod, kernel)
86+
end
7787

78-
# Load the function pointer
79-
func = try
80-
CUDAdrv.CuFunction(mod, kernel)
81-
catch ex
82-
rethrow(CompileError("could not find kernel $kernel in the compiled binary (wrong function name?)", ex))
88+
function clean_cache()
89+
if ispath(builddir)
90+
@assert isdir(builddir)
91+
rm(builddir; recursive=true)
8392
end
84-
85-
return func
8693
end

src/backends/cudanative/cudanative.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ immutable CUFunction{T}
209209
kernel::T
210210
end
211211

212-
if try success(`nvcc --version`); catch false; end
212+
if try success(`$(CUDArt.toolchain_nvcc) --version`); catch false; end
213213
include("compilation.jl")
214214
hasnvcc() = true
215215
else
@@ -236,16 +236,20 @@ function (f::CUFunction{F}){F <: Function, T, N}(A::CUArray{T, N}, args...)
236236
f.kernel, map(unpack_cu_array, args)...
237237
)
238238
end
239-
function cu_convert{T, N}(x::CUArray{T, N})
240-
pointer(buffer(x))
241-
end
242-
cu_convert(x) = x
239+
240+
# Type reported for a `CUArray` argument: a pointer to its element type.
function cudacall_types(x::CUArray{T, N}) where {T, N}
    return Ptr{T}
end

# Any other argument reports its own type.
function cudacall_types(x::T) where T
    return T
end

# Non-array arguments are passed through unchanged.
function cudacall_convert(x)
    return x
end

# A `CUArray` argument is replaced by the pointer to its underlying buffer.
function cudacall_convert(x::CUArray{T, N}) where {T, N}
    return pointer(buffer(x))
end
243245

244246
function (f::CUFunction{F}){F <: CUDAdrv.CuFunction, T, N}(A::CUArray{T, N}, args)
245247
griddim, blockdim = thread_blocks_heuristic(A)
246-
CUDAdrv.launch(
247-
f.kernel, CUDAdrv.CuDim3(griddim...), CUDAdrv.CuDim3(blockdim...), 0, CuDefaultStream(),
248-
map(cu_convert, args)
248+
typs = Tuple{cudacall_types.(args)...}
249+
cuargs = cudacall_convert.(args)
250+
CUDAdrv.cudacall(
251+
f.kernel, CUDAdrv.CuDim3(griddim...), CUDAdrv.CuDim3(blockdim...),
252+
typs, cuargs...
249253
)
250254
end
251255

test/cudanative.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,20 +73,20 @@ end
7373
jy = Array(y)
7474
@test map!(sin, jy, jy) ≈ Array(x)
7575
end
76-
#
77-
# if CUBackend.hasnvcc()
78-
# @testset "Custom kernel from string function" begin
79-
# x = GPUArray(rand(Float32, 100))
80-
# y = GPUArray(rand(Float32, 100))
81-
# source = """
82-
# __global__ void copy(const float *input, float *output)
83-
# {
84-
# int i = blockIdx.x * blockDim.x + threadIdx.x;
85-
# output[i] = input[i];
86-
# }
87-
# """
88-
# f = (source, :copy)
89-
# gpu_call(f, x, (x, y))
90-
# @test Array(x) == Array(y)
91-
# end
92-
# end
76+
77+
if CUBackend.hasnvcc()
78+
@testset "Custom kernel from string function" begin
79+
x = GPUArray(rand(Float32, 100))
80+
y = GPUArray(rand(Float32, 100))
81+
source = """
82+
__global__ void copy(const float *input, float *output)
83+
{
84+
int i = blockIdx.x * blockDim.x + threadIdx.x;
85+
output[i] = input[i];
86+
}
87+
"""
88+
f = (source, :copy)
89+
gpu_call(f, x, (x, y))
90+
@test Array(x) == Array(y)
91+
end
92+
end

0 commit comments

Comments
 (0)