Skip to content

Commit 15b9b09

Browse files
committed
NFCs.
1 parent 3dae8b3 commit 15b9b09

File tree

5 files changed

+76
-59
lines changed

5 files changed

+76
-59
lines changed

src/driver.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,13 @@ const __llvm_initialized = Ref(false)
306306
unsafe_delete!(ir, dyn_marker)
307307
end
308308

309+
finish_module!(job, ir)
310+
309311
return ir, (; entry, compiled)
310312
end
311313

312314
@locked function emit_asm(@nospecialize(job::CompilerJob), ir::LLVM.Module;
313315
strip::Bool=false, validate::Bool=true, format::LLVM.API.LLVMCodeGenFileType)
314-
finish_module!(job, ir)
315-
316316
if validate
317317
@timeit_debug to "validation" begin
318318
check_invocation(job)

src/irgen.jl

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,6 @@ function irgen(@nospecialize(job::CompilerJob), method_instance::Core.MethodInst
8181
add_lowering_passes!(job, pm)
8282

8383
run!(pm, mod)
84-
85-
# NOTE: if an optimization is missing, try scheduling an entirely new optimization
86-
# to see which passes need to be added to the target-specific list
87-
# LLVM.clopts("-print-after-all", "-filter-print-funcs=$(LLVM.name(entry))")
88-
# ModulePassManager() do pm
89-
# add_library_info!(pm, triple(mod))
90-
# add_transform_info!(pm, tm)
91-
# PassManagerBuilder() do pmb
92-
# populate!(pm, pmb)
93-
# end
94-
# run!(pm, mod)
95-
# end
9684
end
9785

9886
return mod, compiled
@@ -380,12 +368,10 @@ function deserves_sret(T, llvmT)
380368
end
381369

382370

383-
## byval lowering
384-
385-
# some back-ends don't support byval, or support it badly
371+
# byval lowering
372+
#
373+
# some back-ends don't support byval, or support it badly, so lower it eagerly ourselves
386374
# https://reviews.llvm.org/D79744
387-
388-
# modify the kernel function to fix & improve argument passing
389375
function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
390376
ctx = context(mod)
391377
ft = eltype(llvmtype(f)::LLVM.PointerType)::LLVM.FunctionType

src/optim.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,21 @@ end
99
function addOptimizationPasses!(pm, opt_level=2)
1010
# compare with using Julia's optimization pipeline directly:
1111
#ccall(:jl_add_optimization_passes, Cvoid,
12-
# (LLVM.API.LLVMPassManagerRef, Cint, Cint),
13-
# pm, opt_level, #=lower_intrinsics=# 0)
12+
# (LLVM.API.LLVMPassManagerRef, Cint, Cint),
13+
# pm, opt_level, #=lower_intrinsics=# 0)
1414
#return
1515

16+
# compare to Clang by using the pass manager builder APIs:
17+
#LLVM.clopts("-print-after-all", "-filter-print-funcs=$(LLVM.name(entry))")
18+
#ModulePassManager() do pm
19+
# add_library_info!(pm, triple(mod))
20+
# add_transform_info!(pm, tm)
21+
# PassManagerBuilder() do pmb
22+
# populate!(pm, pmb)
23+
# end
24+
# run!(pm, mod)
25+
#end
26+
1627
# NOTE: LLVM 12 disabled the hoisting of common instructions
1728
# before loop vectorization (https://reviews.llvm.org/D84108).
1829
#

src/ptx.jl

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -116,47 +116,10 @@ function process_entry!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
116116
mod::LLVM.Module, entry::LLVM.Function)
117117
invoke(process_entry!, Tuple{CompilerJob, LLVM.Module, LLVM.Function}, job, mod, entry)
118118

119-
ctx = context(mod)
120119
if job.source.kernel
121120
# work around bad byval codegen (JuliaGPU/GPUCompiler.jl#92)
122121
entry = lower_byval(job, mod, entry)
123122

124-
# property annotations
125-
annotations = Metadata[entry]
126-
127-
## kernel metadata
128-
append!(annotations, [MDString("kernel"; ctx),
129-
ConstantInt(Int32(1); ctx)])
130-
131-
## expected CTA sizes
132-
if job.target.minthreads !== nothing
133-
for (dim, name) in enumerate([:x, :y, :z])
134-
bound = dim <= length(job.target.minthreads) ? job.target.minthreads[dim] : 1
135-
append!(annotations, [MDString("reqntid$name"; ctx),
136-
ConstantInt(Int32(bound); ctx)])
137-
end
138-
end
139-
if job.target.maxthreads !== nothing
140-
for (dim, name) in enumerate([:x, :y, :z])
141-
bound = dim <= length(job.target.maxthreads) ? job.target.maxthreads[dim] : 1
142-
append!(annotations, [MDString("maxntid$name"; ctx),
143-
ConstantInt(Int32(bound); ctx)])
144-
end
145-
end
146-
147-
if job.target.blocks_per_sm !== nothing
148-
append!(annotations, [MDString("minctasm"; ctx),
149-
ConstantInt(Int32(job.target.blocks_per_sm); ctx)])
150-
end
151-
152-
if job.target.maxregs !== nothing
153-
append!(annotations, [MDString("maxnreg"; ctx),
154-
ConstantInt(Int32(job.target.maxregs); ctx)])
155-
end
156-
157-
push!(metadata(mod)["nvvm.annotations"], MDNode(annotations; ctx))
158-
159-
160123
if LLVM.version() >= v"8"
161124
# calling convention
162125
callconv!(entry, LLVM.API.LLVMPTXKernelCallConv)
@@ -168,6 +131,7 @@ end
168131

169132
function add_lowering_passes!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
170133
pm::LLVM.PassManager)
134+
# hide `unreachable` from LLVM so that it doesn't introduce divergent control flow
171135
if !job.target.unreachable
172136
add!(pm, FunctionPass("HideUnreachable", hide_unreachable!))
173137
end
@@ -208,6 +172,62 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
208172
end
209173
end
210174

175+
function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), mod::LLVM.Module)
176+
ctx = context(mod)
177+
178+
# add metadata annotations for the assembler to the module
179+
# NOTE: we need to do this as late as possible, because otherwise the metadata (which
180+
# refers to a specific function) can get lost when cloning functions. normally
181+
# RAUW updates those references, but we can't RAUW with a changed function type.
182+
if job.source.kernel
183+
# find the entry-point function
184+
# XXX: make this an argument to `emit_asm` again?
185+
entry = nothing
186+
for f in functions(mod)
187+
if callconv(f) == LLVM.API.LLVMPTXKernelCallConv
188+
entry = f
189+
break
190+
end
191+
end
192+
@assert entry !== nothing
193+
194+
# property annotations
195+
annotations = Metadata[entry]
196+
197+
## kernel metadata
198+
append!(annotations, [MDString("kernel"; ctx),
199+
ConstantInt(Int32(1); ctx)])
200+
201+
## expected CTA sizes
202+
if job.target.minthreads !== nothing
203+
for (dim, name) in enumerate([:x, :y, :z])
204+
bound = dim <= length(job.target.minthreads) ? job.target.minthreads[dim] : 1
205+
append!(annotations, [MDString("reqntid$name"; ctx),
206+
ConstantInt(Int32(bound); ctx)])
207+
end
208+
end
209+
if job.target.maxthreads !== nothing
210+
for (dim, name) in enumerate([:x, :y, :z])
211+
bound = dim <= length(job.target.maxthreads) ? job.target.maxthreads[dim] : 1
212+
append!(annotations, [MDString("maxntid$name"; ctx),
213+
ConstantInt(Int32(bound); ctx)])
214+
end
215+
end
216+
217+
if job.target.blocks_per_sm !== nothing
218+
append!(annotations, [MDString("minctasm"; ctx),
219+
ConstantInt(Int32(job.target.blocks_per_sm); ctx)])
220+
end
221+
222+
if job.target.maxregs !== nothing
223+
append!(annotations, [MDString("maxnreg"; ctx),
224+
ConstantInt(Int32(job.target.maxregs); ctx)])
225+
end
226+
227+
push!(metadata(mod)["nvvm.annotations"], MDNode(annotations; ctx))
228+
end
229+
end
230+
211231
function llvm_debug_info(@nospecialize(job::CompilerJob{PTXCompilerTarget}))
212232
# allow overriding the debug info from CUDA.jl
213233
if job.target.debuginfo

test/ptx.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ end
3838
@test !occursin("nvvm.annotations", ir)
3939

4040
ir = sprint(io->ptx_code_llvm(io, kernel, Tuple{};
41-
dump_module=true, kernel=true))
41+
dump_module=true, kernel=true))
4242
@test occursin("nvvm.annotations", ir)
4343
@test !occursin("maxntid", ir)
4444
@test !occursin("reqntid", ir)

0 commit comments

Comments
 (0)