Skip to content

Commit 85fa183

Browse files
authored
Merge pull request #251 from JuliaGPU/tb/state_byval
More kernel state optimizations
2 parents 8db1129 + 22c62e8 commit 85fa183

File tree

9 files changed

+83
-145
lines changed

9 files changed

+83
-145
lines changed

Manifest.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,17 @@ version = "1.3.0"
3939

4040
[[LLVM]]
4141
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
42-
git-tree-sha1 = "756cd7ea042f82962d8d46e378c0f1863bb4dc0f"
42+
git-tree-sha1 = "46092047ca4edc10720ecab437c42283cd7c44f3"
43+
repo-rev = "master"
44+
repo-url = "https://github.com/maleadt/LLVM.jl.git"
4345
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
44-
version = "4.5.3"
46+
version = "4.6.0"
4547

4648
[[LLVMExtra_jll]]
4749
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
48-
git-tree-sha1 = "9c360e5ce980b88bb31a7b086dbb19469008154b"
50+
git-tree-sha1 = "6a2af408fe809c4f1a54d2b3f188fdd3698549d6"
4951
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
50-
version = "0.0.10+0"
52+
version = "0.0.11+0"
5153

5254
[[LibCURL]]
5355
deps = ["LibCURL_jll", "MozillaCACerts_jll"]

src/gcn.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ function lower_throw_extra!(mod::LLVM.Module)
122122
end
123123

124124
# remove the call
125-
call_args = operands(call)[1:end-1] # last arg is function itself
125+
nargs = length(parameters(f))
126+
call_args = arguments(call)
126127
unsafe_delete!(LLVM.parent(call), call)
127128

128129
# HACK: kill the exceptions' unused arguments

src/irgen.jl

Lines changed: 60 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ function lower_throw!(mod::LLVM.Module)
195195
end
196196

197197
# remove the call
198-
call_args = operands(call)[1:end-1] # last arg is function itself
198+
call_args = arguments(call)
199199
unsafe_delete!(LLVM.parent(call), call)
200200

201201
# HACK: kill the exceptions' unused arguments
@@ -406,9 +406,6 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
406406
byval[arg.codegen.i+has_kernel_state] = true
407407
end
408408
end
409-
if has_kernel_state
410-
byval[1] = true
411-
end
412409
end
413410

414411
# fixup metadata
@@ -549,8 +546,16 @@ end
549546
# the storage location for exception flags, or the location of a I/O buffer, we enable the
550547
# back-end to specify a Julia object that will be passed to the kernel by-value, and to
551548
# every called function by-reference. Access to this object is done using the
552-
# `julia.gpu.state_getter` intrinsic, which returns an opaque pointer to the state object.
553-
# after optimization, these intrinsics will be lowered to refer to the state argument.
549+
# `julia.gpu.state_getter` intrinsic. after optimization, these intrinsics will be lowered
550+
# to refer to the state argument.
551+
#
552+
# note that we deviate from the typical Julia calling convention, by always passing the
553+
# state object by value instead of by reference. this ensures that the state object
554+
# is not copied to the stack (because LLVM doesn't see that all uses are read-only).
555+
# in principle, `readonly byval` should be equivalent, but LLVM doesn't realize that.
556+
# also see https://github.com/JuliaGPU/CUDA.jl/pull/1167 and the comments in that PR.
557+
# once LLVM supports this pattern, consider going back to passing the state by reference,
558+
# so that the `julia.gpu.state_getter` intrinsic can be simplified to return an opaque pointer.
554559

555560
# add a state argument to every function in the module, starting from the kernel entry point
556561
function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
@@ -565,13 +570,10 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
565570
return false
566571
end
567572
T_state = convert(LLVMType, state; ctx)
568-
T_ptr_state = LLVM.PointerType(T_state)
569573

570574
# intrinsic returning the kernel state (by value, typed as the state object).
571575
# this is both for extern uses, and to make this transformation a two-step process.
572-
T_int8 = LLVM.IntType(8; ctx)
573-
T_pint8 = LLVM.PointerType(T_int8)
574-
state_intr = kernel_state_intr(mod)
576+
state_intr = kernel_state_intr(mod, T_state)
575577

576578
# add a state argument to every function
577579
worklist = filter(!isdeclaration, collect(functions(mod)))
@@ -582,7 +584,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
582584
LLVM.name!(f, fn * ".stateless")
583585

584586
# create a new function
585-
new_param_types = [T_ptr_state, parameters(ft)...]
587+
new_param_types = [T_state, parameters(ft)...]
586588
new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
587589
new_f = LLVM.Function(mod, fn, new_ft)
588590
LLVM.name!(parameters(new_f)[1], "state")
@@ -618,16 +620,6 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
618620
clone_into!(new_f, f; value_map, materializer,
619621
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
620622

621-
# pass the state by value to the kernel (after cloning, which overwrites attributes)
622-
if f == entry
623-
attr = if LLVM.version() >= v"12"
624-
TypeAttribute("byval", T_state; ctx)
625-
else
626-
EnumAttribute("byval", 0; ctx)
627-
end
628-
push!(parameter_attributes(new_f, 1), attr)
629-
end
630-
631623
# we can't remove this function yet, as we might still need to rewrite any called,
632624
# but remove the IR already
633625
empty!(f)
@@ -656,10 +648,9 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
656648

657649
# forward the state argument
658650
position!(builder, val)
659-
untyped_state = call!(builder, state_intr, Value[], "state")
660-
typed_state = bitcast!(builder, untyped_state, T_ptr_state)
651+
state = call!(builder, state_intr, Value[], "state")
661652
new_val = if val isa LLVM.CallInst
662-
call!(builder, new_f, [typed_state, operands(val)[1:end-1]...])
653+
call!(builder, new_f, [state, arguments(val)...], operand_bundles(val))
663654
else
664655
# TODO: invoke and callbr
665656
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
@@ -695,30 +686,12 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
695686
unsafe_delete!(mod, f)
696687
end
697688

698-
# HACK: add a dummy use of the kernel state pointer to ensure it survives optimization
699-
dummy_user = if haskey(functions(mod), "julia.gpu.state_user")
700-
functions(mod)["julia.gpu.state_user"]
701-
else
702-
LLVM.Function(mod, "julia.gpu.state_user",
703-
LLVM.FunctionType(LLVM.VoidType(ctx), [T_ptr_state]))
704-
end
705-
entry = functions(mod)[entry_fn]
706-
Builder(ctx) do builder
707-
position!(builder, first(instructions(first(blocks(entry)))))
708-
call!(builder, dummy_user, [parameters(entry)[1]])
709-
end
710-
711689
return true
712690
end
713691

714692
# lower calls to the state getter intrinsic. this is a two-step process, so that the state
715693
# argument can be added before optimization, and that optimization can introduce new uses
716694
# before the intrinsic getting lowered late during optimization.
717-
#
718-
# the reason we want to add the state argument before optimization, is that the initial
719-
# argument is marked byval, but some backends need to eagerly lower that byval property
720-
# (because the LLVM back-end doesn't support emitting code for it). That lowering typically
721-
# generates a lot of expensive code, so _needs_ to be optimized.
722695
function lower_kernel_state!(fun::LLVM.Function)
723696
job = current_job::CompilerJob
724697
mod = LLVM.parent(fun)
@@ -731,64 +704,33 @@ function lower_kernel_state!(fun::LLVM.Function)
731704
return false
732705
end
733706

734-
# find the kernel state argument. normally, this is the first argument of the function.
735-
state_arg = nothing
707+
# find the kernel state argument. this should be the first argument of the function.
708+
state_arg = parameters(fun)[1]
736709
T_state = convert(LLVMType, state; ctx)
737-
T_ptr_state = LLVM.PointerType(T_state)
738-
first_arg = parameters(fun)[1]
739-
if llvmtype(first_arg) == T_ptr_state
740-
state_arg = first_arg
741-
end
742-
743-
# with kernels, the story is more complicated: the kernel state argument is marked byval,
744-
# and it's possible we eagerly lowered that pointer to a value. to retrieve the state,
745-
# look for the alloca slot the argument was stored in via the dummy use we introduced.
746-
#
747-
# this is obviously a hack, stemming from the fact that while lowering Julia intrinsics
748-
# (which needs to happen _after_ optimization) we may have to emit calls to the GPU
749-
# runtime while those functions may already have had their kernel state arguments added
750-
# (which we do _before_ optimization to make sure that any lowered byval performs well).
751-
if state_arg === nothing
752-
@assert llvmtype(first_arg) == T_state
753-
dummy_user = functions(mod)["julia.gpu.state_user"]
754-
for use in uses(dummy_user)
755-
call = user(use)
756-
bb = LLVM.parent(call)
757-
if LLVM.parent(bb) == fun
758-
state_arg = operands(call)[1]
759-
break
760-
end
761-
end
762-
end
763-
764-
if state_arg === nothing
765-
error("Internal compiler error: could not reconstruct kernel state argument")
766-
end
767-
768-
# get the intrinsic returning an opaque pointer to the kernel state.
769-
T_int8 = LLVM.IntType(8; ctx)
770-
T_pint8 = LLVM.PointerType(T_int8)
771-
state_intr = kernel_state_intr(mod)
710+
@assert llvmtype(state_arg) == T_state
772711

773712
# fixup all uses of the state getter to use the newly introduced function state argument
774-
Builder(ctx) do builder
775-
for use in uses(state_intr)
776-
inst = user(use)
777-
@assert inst isa LLVM.CallInst
778-
bb = LLVM.parent(inst)
779-
LLVM.parent(bb) == fun || continue
713+
if haskey(functions(mod), "julia.gpu.state_getter")
714+
state_intr = functions(mod)["julia.gpu.state_getter"]
715+
716+
Builder(ctx) do builder
717+
for use in uses(state_intr)
718+
inst = user(use)
719+
@assert inst isa LLVM.CallInst
720+
bb = LLVM.parent(inst)
721+
LLVM.parent(bb) == fun || continue
780722

781-
position!(builder, inst)
782-
bb = LLVM.parent(inst)
783-
f = LLVM.parent(bb)
723+
position!(builder, inst)
724+
bb = LLVM.parent(inst)
725+
f = LLVM.parent(bb)
784726

785-
untyped_state = bitcast!(builder, state_arg, T_pint8)
786-
replace_uses!(inst, untyped_state)
727+
replace_uses!(inst, state_arg)
787728

788-
@assert isempty(uses(inst))
789-
unsafe_delete!(LLVM.parent(inst), inst)
729+
@assert isempty(uses(inst))
730+
unsafe_delete!(LLVM.parent(inst), inst)
790731

791-
changed = true
732+
changed = true
733+
end
792734
end
793735
end
794736

@@ -810,45 +752,44 @@ function cleanup_kernel_state!(mod::LLVM.Module)
810752
end
811753
end
812754

813-
# remove the kernel state dummy use
814-
if haskey(functions(mod), "julia.gpu.state_user")
815-
intr = functions(mod)["julia.gpu.state_user"]
816-
for use in uses(intr)
817-
call = user(use)
818-
unsafe_delete!(LLVM.parent(call), call)
819-
end
820-
@assert isempty(uses(intr))
821-
unsafe_delete!(mod, intr)
822-
changed = true
823-
end
824-
825755
return changed
826756
end
827757

828-
function kernel_state_intr(mod::LLVM.Module)
758+
function kernel_state_intr(mod::LLVM.Module, T_state)
829759
ctx = context(mod)
830-
T_int8 = LLVM.IntType(8; ctx)
831-
T_pint8 = LLVM.PointerType(T_int8)
832760

833761
state_intr = if haskey(functions(mod), "julia.gpu.state_getter")
834762
functions(mod)["julia.gpu.state_getter"]
835763
else
836-
LLVM.Function(mod, "julia.gpu.state_getter", LLVM.FunctionType(T_pint8))
764+
LLVM.Function(mod, "julia.gpu.state_getter", LLVM.FunctionType(T_state))
837765
end
838766
push!(function_attributes(state_intr), EnumAttribute("readnone", 0; ctx))
839767

840768
return state_intr
841769
end
842770

843-
# run-time equivalent (untyped)
844-
@inline kernel_state_pointer() = Base.llvmcall(("""
845-
declare i8* @julia.gpu.state_getter()
771+
# run-time equivalent
772+
function kernel_state_value(state)
773+
Context() do ctx
774+
T_state = convert(LLVMType, state; ctx)
775+
776+
# create function
777+
llvm_f, _ = create_function(T_state)
778+
mod = LLVM.parent(llvm_f)
779+
780+
# get intrinsic
781+
state_intr = kernel_state_intr(mod, T_state)
782+
783+
# generate IR
784+
Builder(ctx) do builder
785+
entry = BasicBlock(llvm_f, "entry"; ctx)
786+
position!(builder, entry)
787+
788+
val = call!(builder, state_intr, Value[], "state")
846789

847-
define i64 @entry() #0 {
848-
%ptls = call i8* @julia.gpu.state_getter()
849-
%ptr = ptrtoint i8* %ptls to i64
850-
ret i64 %ptr
851-
}
790+
ret!(builder, val)
791+
end
852792

853-
attributes #0 = { alwaysinline readnone }""", "entry"),
854-
Ptr{Cvoid}, Tuple{})
793+
call_function(llvm_f, state)
794+
end
795+
end

src/optim.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ function lower_gc_frame!(fun::LLVM.Function)
262262
call = user(use)::LLVM.CallInst
263263

264264
# decode the call
265-
ops = operands(call)
265+
ops = arguments(call)
266266
sz = ops[2]
267267

268268
# replace with PTX alloc_obj

src/rtlib.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,10 @@ function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[
4949
args = Value[args...]
5050
if state !== Nothing
5151
T_state = convert(LLVMType, state; ctx)
52-
T_ptr_state = LLVM.PointerType(T_state)
5352

54-
state_intr = kernel_state_intr(mod)
55-
untyped_state = call!(builder, state_intr, Value[], "state")
56-
typed_state = bitcast!(builder, untyped_state, T_ptr_state)
57-
pushfirst!(args, typed_state)
53+
state_intr = kernel_state_intr(mod, T_state)
54+
state_val = call!(builder, state_intr, Value[], "state")
55+
pushfirst!(args, state_val)
5856
end
5957

6058
# runtime functions are written in Julia, while we're calling from LLVM,

src/runtime.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ function Base.convert(::Type{LLVM.FunctionType}, rt::RuntimeMethodInstance;
4444
# if we're running post-optimization, prepend the kernel state to the argument list
4545
if state !== Nothing
4646
T_state = convert(LLVMType, state; ctx)
47-
T_ptr_state = LLVM.PointerType(T_state)
48-
pushfirst!(types, T_ptr_state)
47+
pushfirst!(types, T_state)
4948
end
5049

5150
return_type = if rt.llvm_return_type === nothing

src/validation.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
141141
fn = LLVM.name(dest)
142142

143143
# some special handling for runtime functions that we don't implement
144-
if fn == "jl_get_binding_or_error"
144+
if fn == "jl_get_binding_or_error" || fn == "ijl_get_binding_or_error"
145145
try
146-
m, sym, _ = operands(inst)
146+
m, sym = arguments(inst)
147147
sym = first(operands(sym::ConstantExpr))::ConstantInt
148148
sym = convert(Int, sym)
149149
sym = Ptr{Cvoid}(sym)
@@ -153,9 +153,9 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
153153
@debug "Decoding arguments to jl_get_binding_or_error failed" inst bb=LLVM.parent(inst)
154154
push!(errors, (DELAYED_BINDING, bt, nothing))
155155
end
156-
elseif fn == "jl_invoke"
156+
elseif fn == "jl_invoke" || fn == "ijl_invoke"
157157
try
158-
f, args, nargs, meth = operands(inst)
158+
f, args, nargs, meth = arguments(inst)
159159
meth = first(operands(meth::ConstantExpr))::ConstantInt
160160
meth = convert(Int, meth)
161161
meth = Ptr{Cvoid}(meth)
@@ -165,9 +165,9 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
165165
@debug "Decoding arguments to jl_invoke failed" inst bb=LLVM.parent(inst)
166166
push!(errors, (DYNAMIC_CALL, bt, nothing))
167167
end
168-
elseif fn == "jl_apply_generic"
168+
elseif fn == "jl_apply_generic" || fn == "ijl_apply_generic"
169169
try
170-
f, args, nargs, _ = operands(inst)
170+
f, args, nargs = arguments(inst)
171171
f = first(operands(f))::ConstantInt # get rid of inttoptr
172172
f = convert(Int, f)
173173
f = Ptr{Cvoid}(f)
@@ -201,7 +201,7 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
201201

202202
elseif isa(dest, ConstantExpr)
203203
# detect calls to literal pointers
204-
if occursin("inttoptr", string(dest))
204+
if opcode(dest) == LLVM.API.LLVMIntToPtr
205205
# extract the literal pointer
206206
ptr_arg = first(operands(dest))
207207
@compiler_assert isa(ptr_arg, ConstantInt) job

0 commit comments

Comments
 (0)