Skip to content

Commit d74ddc5

Browse files
committed
[Metal] Emit global dynamic memory
1 parent c428508 commit d74ddc5

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

src/metal.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
168168

169169
add_argument_metadata!(job, mod, entry)
170170

171+
add_globals_metadata!(job, mod, entry)
172+
171173
add_module_metadata!(job, mod)
172174
end
173175

@@ -550,6 +552,96 @@ function argument_type_name(typ)
550552
end
551553
end
552554

555+
# global metadata generation
556+
#
557+
# module metadata is used to identify global buffers that are used as kernel arguments.
558+
function add_globals_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
559+
entry::LLVM.Function)
560+
entry_ft = function_type(entry)
561+
562+
## argument info
563+
arg_infos = Metadata[]
564+
565+
566+
# Iterate through arguments and create metadata for them
567+
globs = globals(mod)
568+
@show globs
569+
i = 1
570+
for gv in globs
571+
@show gv
572+
gv_typ = global_value_type(gv)
573+
(isconstant(gv) && addrspace(gv_typ) == 3) || continue
574+
# if job.config.optimize
575+
# @assert parameters(entry_ft)[arg.idx] isa LLVM.PointerType
576+
# else
577+
# parameters(entry_ft)[arg.idx] isa LLVM.PointerType || continue
578+
# end
579+
580+
# # NOTE: we emit the bare minimum of argument metadata to support
581+
# # bindless argument encoding. Actually using the argument encoder
582+
# # APIs (deprecated in Metal 3) turned out too difficult, given the
583+
# # undocumented nature of the argument metadata, and the complex
584+
# # arguments we encounter with typical Julia kernels.
585+
global_infos = Metadata[]
586+
587+
push!(global_infos, MDString("air.global_binding"))
588+
push!(global_infos, Metadata(gv))
589+
590+
md = Metadata[]
591+
592+
# argument index
593+
push!(md, Metadata(ConstantInt(Int32(-1))))
594+
595+
push!(md, MDString("air.buffer"))
596+
597+
push!(md, MDString("air.location_index"))
598+
push!(md, Metadata(ConstantInt(Int32(i-1))))
599+
600+
# XXX: unknown
601+
push!(md, Metadata(ConstantInt(Int32(1))))
602+
603+
push!(md, MDString("air.read_write")) # TODO: Check for const array
604+
605+
push!(md, MDString("air.address_space"))
606+
push!(md, Metadata(ConstantInt(Int32(addrspace(global_value_type(gv))))))
607+
608+
val_type = global_value_type(gv)
609+
# val_type = if value_type(gv) <: Core.LLVMPtr
610+
# arg.typ.parameters[1]
611+
# else
612+
# arg.typ
613+
# end
614+
615+
@show gv_typ
616+
@show isconstant(gv)
617+
# @show isconstant(gv_typ)
618+
# @show Int32(alignment(gv))
619+
620+
push!(md, MDString("air.arg_type_size"))
621+
push!(md, Metadata(ConstantInt(Int32(4))))
622+
623+
push!(md, MDString("air.arg_type_align_size"))
624+
push!(md, Metadata(ConstantInt(Int32(alignment(gv)))))
625+
626+
push!(md, MDString("air.arg_type_name"))
627+
# push!(md, MDString(repr(arg.typ)))
628+
629+
push!(md, MDString("air.arg_name"))
630+
push!(md, MDString(String(LLVM.name(gv))))
631+
632+
push!(arg_infos, MDNode(md))
633+
634+
i += 1
635+
end
636+
637+
println()
638+
arg_infos = MDNode(arg_infos)
639+
640+
push!(metadata(mod)["air.global_bindings"], arg_infos)
641+
642+
return
643+
end
644+
553645
# argument metadata generation
554646
#
555647
# module metadata is used to identify buffers that are passed as kernel arguments.

0 commit comments

Comments
 (0)