Skip to content

Commit 0eb6e56

Browse files
committed
Zero-init host and runtime Refs
1 parent 2247de6 commit 0eb6e56

File tree

4 files changed

+92
-30
lines changed

4 files changed

+92
-30
lines changed

src/BPFnative.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ module RT
6868
import ..API
6969
using ..LLVM
7070
using ..LLVM.Interop
71+
import Core: LLVMPtr
7172
include("runtime/maps_core.jl")
7273
include("runtime/bpfcall.jl")
7374
include("runtime/maps.jl")

src/host/maps.jl

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,25 @@ struct map_access_elem_attr
4646
flags::UInt64
4747
end
4848

49+
memset!(ptr::Ptr{T}) where T =
50+
ccall(:memset, Cvoid,
51+
(Ptr{T}, UInt8, UInt64),
52+
ptr, UInt8(0), sizeof(T))
53+
"Creates a Ref{T} that's been zero-initialized before being stored. Necessary
54+
to ensure that struct padding bytes are zeroed."
55+
function ZeroInitRef(T, val; set=true)
56+
ref = Ref{T}()
57+
memset!(Base.unsafe_convert(Ptr{T}, ref))
58+
if set
59+
ref[] = val
60+
end
61+
ref
62+
end
63+
ZeroInitRef(val::T) where T = ZeroInitRef(T, val)
64+
ZeroInitRef(::Type{T}) where T = ZeroInitRef(T, nothing; set=false)
65+
4966
function Base.getindex(map::AbstractHashMap{K,V}, idx) where {K,V}
50-
key = Ref{K}(idx)
67+
key = ZeroInitRef(K, idx)
5168
value = Ref{V}()
5269
key_ptr = Base.unsafe_convert(Ptr{K}, key)
5370
value_ptr = Base.unsafe_convert(Ptr{V}, value)
@@ -62,7 +79,7 @@ function Base.getindex(map::AbstractHashMap{K,V}, idx) where {K,V}
6279
value[]
6380
end
6481
function Base.getindex(map::AbstractArrayMap{K,V}, idx) where {K,V}
65-
key = Ref{K}(idx-1)
82+
key = ZeroInitRef(K, idx-1)
6683
value = Ref{V}()
6784
key_ptr = Base.unsafe_convert(Ptr{K}, key)
6885
value_ptr = Base.unsafe_convert(Ptr{V}, value)
@@ -78,8 +95,8 @@ function Base.getindex(map::AbstractArrayMap{K,V}, idx) where {K,V}
7895
end
7996

8097
function Base.setindex!(map::AbstractHashMap{K,V}, value::U, idx) where {K,V,U}
81-
key_ref = Ref{K}(convert(K,idx))
82-
value_ref = Ref{V}(convert(V,value))
98+
key_ref = ZeroInitRef(convert(K,idx))
99+
value_ref = ZeroInitRef(convert(V,value))
83100
key_ptr = Base.unsafe_convert(Ptr{K}, key_ref)
84101
value_ptr = Base.unsafe_convert(Ptr{V}, value_ref)
85102
attr = Ref(map_access_elem_attr(map.fd,
@@ -93,8 +110,8 @@ function Base.setindex!(map::AbstractHashMap{K,V}, value::U, idx) where {K,V,U}
93110
value
94111
end
95112
function Base.setindex!(map::AbstractArrayMap{K,V}, value::U, idx) where {K,V,U}
96-
key_ref = Ref{K}(convert(K,idx-1))
97-
value_ref = Ref{V}(convert(V,value))
113+
key_ref = ZeroInitRef(convert(K,idx-1))
114+
value_ref = ZeroInitRef(convert(V,value))
98115
key_ptr = Base.unsafe_convert(Ptr{K}, key_ref)
99116
value_ptr = Base.unsafe_convert(Ptr{V}, value_ref)
100117
attr = Ref(map_access_elem_attr(map.fd,
@@ -109,7 +126,7 @@ function Base.setindex!(map::AbstractArrayMap{K,V}, value::U, idx) where {K,V,U}
109126
end
110127

111128
function Base.delete!(map::AbstractHashMap{K,V}, idx) where {K,V}
112-
key = Ref{K}(idx)
129+
key = ZeroInitRef(K, idx)
113130
key_ptr = Base.unsafe_convert(Ptr{K}, key)
114131
attr = Ref(map_access_elem_attr(map.fd,
115132
key_ptr,
@@ -122,7 +139,7 @@ function Base.delete!(map::AbstractHashMap{K,V}, idx) where {K,V}
122139
end
123140

124141
function Base.haskey(map::HostMap{K,V}, idx) where {K,V}
125-
key = Ref{K}(idx)
142+
key = ZeroInitRef(K, idx)
126143
value = Ref{V}()
127144
key_ptr = Base.unsafe_convert(Ptr{K}, key)
128145
value_ptr = Base.unsafe_convert(Ptr{V}, value)
@@ -136,7 +153,7 @@ function Base.haskey(map::HostMap{K,V}, idx) where {K,V}
136153
end
137154

138155
function nextkey(map::HostMap{K,V}, idx) where {K,V}
139-
key = Ref{K}(idx)
156+
key = ZeroInitRef(K, idx)
140157
nkey = Ref{K}()
141158
key_ptr = Base.unsafe_convert(Ptr{K}, key)
142159
nkey_ptr = Base.unsafe_convert(Ptr{K}, nkey)

src/runtime/maps.jl

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
11
function map_lookup_elem(map::RTMap{Name,MT,K,V,ME,F}, key::K) where {Name,MT,K,V,ME,F}
2-
keyref = Ref{K}(key)
2+
keyref = ZeroInitRef(K, key)
33
GC.@preserve keyref begin
4-
_map_lookup_elem(map, Base.unsafe_convert(Ptr{K}, keyref))
4+
keyref_ptr = Base.unsafe_convert(Ptr{K}, keyref)
5+
_map_lookup_elem(map, keyref_ptr)
56
end
67
end
7-
function map_update_elem(map::RTMap{Name,MT,K,V,ME,F}, key::K, value::V, flags::UInt64) where {Name,MT,K,V,ME,F}
8-
keyref = Ref{K}(key)
9-
valref = Ref{V}(value)
8+
@inline function map_update_elem(map::RTMap{Name,MT,K,V,ME,F}, key::K, value::V, flags::UInt64) where {Name,MT,K,V,ME,F}
9+
keyref = ZeroInitRef(K, key)
10+
valref = ZeroInitRef(V, value)
1011
GC.@preserve keyref valref begin
11-
_map_update_elem(map,
12-
Base.unsafe_convert(Ptr{K}, keyref),
13-
Base.unsafe_convert(Ptr{V}, valref),
14-
flags)
12+
keyref_ptr = Base.unsafe_convert(Ptr{K}, keyref)
13+
valref_ptr = Base.unsafe_convert(Ptr{V}, valref)
14+
_map_update_elem(map, keyref_ptr, valref_ptr, flags)
1515
end
1616
end
1717
function map_delete_elem(map::RTMap{Name,MT,K,V,ME,F}, key::K) where {Name,MT,K,V,ME,F}
18-
keyref = Ref{K}(key)
18+
keyref = ZeroInitRef(K, key)
1919
GC.@preserve keyref begin
20+
keyref_ptr = Base.unsafe_convert(Ptr{K}, keyref)
2021
_map_delete_elem(map, Base.unsafe_convert(Ptr{K}, keyref))
2122
end
2223
end
24+
25+
# TODO: Use bpfcall
2326
@generated function _map_lookup_elem(map::RTMap{Name,MT,K,V,ME,F}, key::Ptr{K}) where {Name,MT,K,V,ME,F}
2427
JuliaContext() do ctx
2528
T_keyp = LLVM.PointerType(convert(LLVMType, K, ctx))
@@ -125,15 +128,15 @@ end
125128
@inline Base.getindex(map::RTMap{Name,MT,K,V,ME,F}, idx) where {Name,MT,K,V,ME,F} =
126129
getindex(map, bpfconvert(K, idx))
127130
Base.getindex(map::RTMap, ::Nothing) = nothing
128-
function Base.getindex(map::AbstractHashMap{Name,MT,K,V,ME,F}, idx::K) where {Name,MT,K,V,ME,F}
131+
@inline function Base.getindex(map::AbstractHashMap{Name,MT,K,V,ME,F}, idx::K) where {Name,MT,K,V,ME,F}
129132
ptr = map_lookup_elem(map, idx)
130133
if reinterpret(UInt64, ptr) > 0
131134
return unsafe_load(ptr)
132135
else
133136
return nothing
134137
end
135138
end
136-
function Base.getindex(map::AbstractArrayMap{Name,MT,K,V,ME,F}, idx::K) where {Name,MT,K,V,ME,F}
139+
@inline function Base.getindex(map::AbstractArrayMap{Name,MT,K,V,ME,F}, idx::K) where {Name,MT,K,V,ME,F}
137140
if idx > 0
138141
ptr = map_lookup_elem(map, idx-K(1))
139142
if reinterpret(UInt64, ptr) > 0
@@ -148,14 +151,14 @@ end
148151

149152
@inline Base.setindex!(map::RTMap{Name,MT,K,V,ME,F}, value, idx) where {Name,MT,K,V,ME,F} =
150153
setindex!(map, bpfconvert(V, value), bpfconvert(K, idx))
151-
Base.setindex!(map::RTMap, ::Nothing, idx) = nothing
152-
Base.setindex!(map::RTMap, value, ::Nothing) = nothing
153-
Base.setindex!(map::RTMap, ::Nothing, ::Nothing) = nothing
154-
function Base.setindex!(map::AbstractHashMap{Name,MT,K,V,ME,F}, value::V, idx::K) where {Name,MT,K,V,ME,F}
154+
@inline Base.setindex!(map::RTMap, ::Nothing, idx) = nothing
155+
@inline Base.setindex!(map::RTMap, value, ::Nothing) = nothing
156+
@inline Base.setindex!(map::RTMap, ::Nothing, ::Nothing) = nothing
157+
@inline function Base.setindex!(map::AbstractHashMap{Name,MT,K,V,ME,F}, value::V, idx::K) where {Name,MT,K,V,ME,F}
155158
map_update_elem(map, idx, value, UInt64(0))
156159
value
157160
end
158-
function Base.setindex!(map::AbstractArrayMap{Name,MT,K,V,ME,F}, value::V, idx::K) where {Name,MT,K,V,ME,F}
161+
@inline function Base.setindex!(map::AbstractArrayMap{Name,MT,K,V,ME,F}, value::V, idx::K) where {Name,MT,K,V,ME,F}
159162
if idx > 0
160163
map_update_elem(map, idx-K(1), value, UInt64(0))
161164
end
@@ -176,18 +179,18 @@ end
176179
map
177180
end
178181

179-
Base.haskey(map::AbstractHashMap{Name,MT,K,V,ME,F}, idx) where {Name,MT,K,V,ME,F} =
182+
@inline Base.haskey(map::AbstractHashMap{Name,MT,K,V,ME,F}, idx) where {Name,MT,K,V,ME,F} =
180183
map[bpfconvert(K, idx)] !== nothing
181-
Base.haskey(map::RTMap, ::Nothing) = false
182-
function Base.haskey(map::AbstractArrayMap{Name,MT,K,V,ME,F}, idx) where {Name,MT,K,V,ME,F}
184+
@inline Base.haskey(map::RTMap, ::Nothing) = false
185+
@inline function Base.haskey(map::AbstractArrayMap{Name,MT,K,V,ME,F}, idx) where {Name,MT,K,V,ME,F}
183186
if idx > 0
184187
map[bpfconvert(K, idx)-K(1)] !== nothing
185188
else
186189
false
187190
end
188191
end
189192

190-
function Base.get(map::RTMap{Name,MT,K,V,ME,F}, k::K, v::V) where {Name,MT,K,V,ME,F}
193+
@inline function Base.get(map::RTMap{Name,MT,K,V,ME,F}, k::K, v::V) where {Name,MT,K,V,ME,F}
191194
map_v = map[k]
192195
if map_v !== nothing
193196
return map_v

src/runtime/maps_core.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,44 @@ function _genmap!(mod::LLVM.Module, ::Type{<:RTMap{Name,MT,K,V,ME,F}}, ctx) wher
2727
return gv
2828
end
2929

30+
# From AMDGPU.jl/src/device/gcn/memory_static.jl
31+
@inline function _memset!(builder, ctx, mod, dest, value, len, volatile)
32+
T_nothing = LLVM.VoidType(ctx)
33+
T_dest = llvmtype(dest)
34+
T_int8 = convert(LLVMType, UInt8, ctx)
35+
T_int64 = convert(LLVMType, UInt64, ctx)
36+
T_int1 = LLVM.Int1Type(ctx)
37+
38+
T_intr = LLVM.FunctionType(T_nothing, [T_dest, T_int8, T_int64, T_int1])
39+
intr = LLVM.Function(mod, "llvm.memset.p$(Int(addrspace(T_dest)))i8.i64", T_intr)
40+
call!(builder, intr, [dest, value, len, volatile])
41+
end
42+
@inline @generated function memset!(dest_ptr::LLVMPtr{UInt8,DestAS}, value::UInt8, len::LT) where {DestAS,LT<:Union{Int64,UInt64}}
43+
JuliaContext() do ctx
44+
T_nothing = LLVM.VoidType(ctx)
45+
T_pint8_dest = convert(LLVMType, dest_ptr, ctx)
46+
T_int8 = convert(LLVMType, UInt8, ctx)
47+
T_int64 = convert(LLVMType, UInt64, ctx)
48+
T_int1 = LLVM.Int1Type(ctx)
49+
50+
llvm_f, _ = create_function(T_nothing, [T_pint8_dest, T_int8, T_int64])
51+
mod = LLVM.parent(llvm_f)
52+
Builder(ctx) do builder
53+
entry = BasicBlock(llvm_f, "entry", ctx)
54+
position!(builder, entry)
55+
56+
_memset!(builder, ctx, mod, parameters(llvm_f)[1], parameters(llvm_f)[2], parameters(llvm_f)[3], ConstantInt(T_int1, 0))
57+
ret!(builder)
58+
end
59+
call_function(llvm_f, Nothing, Tuple{LLVMPtr{UInt8,DestAS},UInt8,LT}, :((dest_ptr, value, len)))
60+
end
61+
end
62+
@inline memset!(dest_ptr::LLVMPtr{T,DestAS}, value::UInt8, len::Integer) where {T,DestAS} =
63+
memset!(reinterpret(LLVMPtr{UInt8,DestAS}, dest_ptr), value, UInt64(len))
64+
@inline function ZeroInitRef(T, val)
65+
ref = Ref{T}()
66+
ref_llptr = reinterpret(LLVMPtr{T,0}, Base.unsafe_convert(Ptr{T}, ref))
67+
memset!(ref_llptr, UInt8(0), sizeof(T))
68+
ref[] = val
69+
ref
70+
end

0 commit comments

Comments
 (0)