Skip to content

Commit 68311b2

Browse files
authored
Merge pull request #319 from christiangnrd/minor-cleanup
Minor cleanup
2 parents bb33fa5 + 27e5e19 commit 68311b2

File tree

9 files changed

+43
-46
lines changed

9 files changed

+43
-46
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ Manifest.toml
44
*.trace
55
wip.*
66
dev
7+
.vscode

Project.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
2727
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2828
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2929

30+
[weakdeps]
31+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
32+
33+
[extensions]
34+
SpecialFunctionsExt = "SpecialFunctions"
35+
3036
[compat]
3137
Adapt = "4"
3238
Artifacts = "1"
@@ -47,11 +53,5 @@ SHA = "0.7"
4753
StaticArrays = "1"
4854
julia = "1.8"
4955

50-
[extensions]
51-
SpecialFunctionsExt = "SpecialFunctions"
52-
5356
[extras]
5457
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
55-
56-
[weakdeps]
57-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

lib/mtl/buffer.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer;
3333
return MTLBuffer(ptr)
3434
end
3535

36-
function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer, ptr::Ptr;
36+
function MTLBuffer(dev::MTLDevice, bytesize::Integer, ptr::Ptr;
3737
storage=Managed, hazard_tracking=DefaultTracking,
3838
cache_mode=DefaultCPUCache)
3939
storage == Private && error("Can't create a Private copy-allocated buffer.")
4040
opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode
4141

42-
@assert 0 < bytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap
42+
@assert 0 < bytesize <= dev.maxBufferLength
4343
ptr = alloc_buffer(dev, bytesize, opts, ptr)
4444

4545
return MTLBuffer(ptr)
@@ -58,10 +58,6 @@ alloc_buffer(dev::MTLDevice, bytesize, opts, ptr::Ptr) =
5858
alloc_buffer(dev::MTLHeap, bytesize, opts) =
5959
@objc [dev::id{MTLHeap} newBufferWithLength:bytesize::NSUInteger
6060
options:opts::MTLResourceOptions]::id{MTLBuffer}
61-
alloc_buffer(dev::MTLHeap, bytesize, opts, ptr::Ptr) =
62-
@objc [dev::id{MTLHeap} newBufferWithBytes:ptr::Ptr{Cvoid}
63-
length:bytesize::NSUInteger
64-
options:opts::MTLResourceOptions]::id{MTLBuffer}
6561

6662
"""
6763
DidModifyRange!(buf::MTLBuffer, range::UnitRange)

src/array.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ Base.sizeof(x::MtlArray) = Base.elsize(x) * length(x)
169169

170170
@inline function Base.pointer(x::MtlArray{T}, i::Integer=1; storage=Private) where {T}
171171
PT = if storage == Private
172-
MtlPointer{T}
172+
MtlPtr{T}
173173
elseif storage == Shared || storage == Managed
174174
Ptr{T}
175175
else
@@ -179,9 +179,9 @@ Base.sizeof(x::MtlArray) = Base.elsize(x) * length(x)
179179
end
180180

181181

182-
function Base.unsafe_convert(::Type{MtlPointer{T}}, x::MtlArray) where {T}
182+
function Base.unsafe_convert(::Type{MtlPtr{T}}, x::MtlArray) where {T}
183183
buf = x.data[]
184-
MtlPointer{T}(buf, x.offset*Base.elsize(x))
184+
MtlPtr{T}(buf, x.offset*Base.elsize(x))
185185
end
186186

187187
function Base.unsafe_convert(::Type{Ptr{S}}, x::MtlArray{T}) where {S, T}
@@ -479,15 +479,15 @@ Base.unsafe_convert(::Type{MTL.MTLBuffer}, A::PermutedDimsArray) =
479479
## unsafe_wrap
480480

481481
function Base.unsafe_wrap(::Type{<:Array}, arr::MtlArray{T,N}, dims=size(arr); own=false) where {T,N}
482-
return unsafe_wrap(Array{T,N}, arr.data[], dims; own=own)
482+
return unsafe_wrap(Array{T,N}, arr.data[], dims; own)
483483
end
484484

485485
function Base.unsafe_wrap(t::Type{<:Array{T}}, buf::MTLBuffer, dims; own=false) where T
486486
ptr = convert(Ptr{T}, buf)
487487
return unsafe_wrap(t, ptr, dims; own)
488488
end
489489

490-
function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPointer{T}, dims; own=false) where T
490+
function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPtr{T}, dims; own=false) where T
491491
return unsafe_wrap(t, convert(Ptr{T}, ptr), dims; own)
492492
end
493493

@@ -513,7 +513,7 @@ function Base.resize!(A::MtlVector{T}, n::Integer) where T
513513
# replace the data with a new one. this 'unshares' the array.
514514
# as a result, we can safely support resizing unowned buffers.
515515
buf = alloc(device(A), bufsize; storage=storagemode(A))
516-
ptr = MtlPointer{T}(buf)
516+
ptr = MtlPtr{T}(buf)
517517
m = min(length(A), n)
518518
if m > 0
519519
unsafe_copyto!(device(A), ptr, pointer(A), m)

src/compiler/execution.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ function Adapt.adapt_storage(to::Adaptor, buf::MTLBuffer)
112112
end
113113
reinterpret(Core.LLVMPtr{Nothing,AS.Device}, buf.gpuAddress)
114114
end
115-
function Adapt.adapt_storage(to::Adaptor, ptr::MtlPointer{T}) where {T}
115+
function Adapt.adapt_storage(to::Adaptor, ptr::MtlPtr{T}) where {T}
116116
reinterpret(Core.LLVMPtr{T,AS.Device}, adapt(to, ptr.buffer)) + ptr.offset
117117
end
118118

@@ -209,7 +209,7 @@ const _kernel_instances = Dict{UInt, Any}()
209209
end
210210

211211
# the arguments passed into this function have not been `mtlconvert`ed, because we need
212-
# to retain the top-level MTLBuffer and MtlPointer objects. eager conversion of nested
212+
# to retain the top-level MTLBuffer and MtlPtr objects. eager conversion of nested
213213
# such objects to LLVMPtr seems fine, somehow.
214214
# TODO: can we just convert everything eagerly and support top-level LLVMPtrs?
215215

@@ -219,7 +219,7 @@ const _kernel_instances = Dict{UInt, Any}()
219219
if argtyp <: MTLBuffer
220220
# top-level buffers are passed as a pointer-valued argument
221221
push!(ex.args, :(set_buffer!(cce, $argex, 0, $idx)))
222-
elseif argtyp <: MtlPointer
222+
elseif argtyp <: MtlPtr
223223
# the same as a buffer, but with an offset
224224
push!(ex.args, :(set_buffer!(cce, $argex.buffer, $argex.offset, $idx)))
225225
elseif isghosttype(argtyp) || Core.Compiler.isconstType(argtyp)
@@ -253,7 +253,7 @@ end
253253
end
254254

255255
# pass by reference, in an argument buffer
256-
argument_buffer = alloc(kernel.pipeline.device, sizeof(argtyp), storage=Shared)
256+
argument_buffer = alloc(kernel.pipeline.device, sizeof(argtyp); storage=Shared)
257257
argument_buffer.label = "MTLBuffer for kernel argument"
258258
unsafe_store!(convert(Ptr{argtyp}, argument_buffer), arg)
259259
return argument_buffer

src/memory.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,39 @@
77
# we cannot take a MTLBuffer's handle and work with that as it were a pointer to memory.
88
# instead, the Metal APIs always take the original handle and an offset parameter.
99

10-
struct MtlPointer{T}
10+
struct MtlPtr{T}
1111
buffer::MTLBuffer
1212
offset::UInt # in bytes
1313

14-
function MtlPointer{T}(buffer::MTLBuffer, offset=0) where {T}
14+
function MtlPtr{T}(buffer::MTLBuffer, offset=0) where {T}
1515
new(buffer, offset)
1616
end
1717
end
1818

19-
Base.eltype(::Type{<:MtlPointer{T}}) where {T} = T
19+
Base.eltype(::Type{<:MtlPtr{T}}) where {T} = T
2020

2121
# limited arithmetic
22-
Base.:(+)(x::MtlPointer{T}, y::Integer) where {T} = MtlPointer{T}(x.buffer, x.offset+y)
23-
Base.:(-)(x::MtlPointer{T}, y::Integer) where {T} = MtlPointer{T}(x.buffer, x.offset-y)
24-
Base.:(+)(x::Integer, y::MtlPointer{T}) where {T} = MtlPointer{T}(x.buffer, y+x.offset)
22+
Base.:(+)(x::MtlPtr{T}, y::Integer) where {T} = MtlPtr{T}(x.buffer, x.offset+y)
23+
Base.:(-)(x::MtlPtr{T}, y::Integer) where {T} = MtlPtr{T}(x.buffer, x.offset-y)
24+
Base.:(+)(x::Integer, y::MtlPtr{T}) where {T} = MtlPtr{T}(x.buffer, y+x.offset)
2525

26-
Base.convert(::Type{Ptr{T}}, ptr::MtlPointer) where {T} =
26+
Base.convert(::Type{Ptr{T}}, ptr::MtlPtr) where {T} =
2727
convert(Ptr{T}, ptr.buffer) + ptr.offset
2828

2929

3030
## operations
3131

3232
# CPU -> GPU
33-
function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPointer{T}, src::Ptr{T}, N::Integer;
33+
function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T}, src::Ptr{T}, N::Integer;
3434
queue::MTLCommandQueue=global_queue(dev), async::Bool=false) where T
3535
storage_type = dst.buffer.storageMode
3636
if storage_type == MTL.MTLStorageModePrivate
3737
# stage through a shared buffer
38-
# shared = alloc(dev, N*sizeof(T), src, storage=Shared)
38+
# shared = alloc(dev, N*sizeof(T), src; storage=Shared)
3939
# unsafe_copyto!(dev, dst, pointer(shared), N; queue, async=false)
4040
# free(shared)
41-
tmp_buf = alloc(dev, N*sizeof(T), src, storage=Shared) #CPU -> GPU (Shared)
42-
unsafe_copyto!(dev, MtlPointer{T}(dst.buffer, dst.offset), MtlPointer{T}(tmp_buf, 0), N; queue, async=false) # GPU (Shared) -> GPU (Private)
41+
tmp_buf = alloc(dev, N*sizeof(T), src; storage=Shared) #CPU -> GPU (Shared)
42+
unsafe_copyto!(dev, MtlPtr{T}(dst.buffer, dst.offset), MtlPtr{T}(tmp_buf, 0), N; queue, async=false) # GPU (Shared) -> GPU (Private)
4343
free(tmp_buf)
4444
elseif storage_type == MTL.MTLStorageModeShared
4545
unsafe_copyto!(convert(Ptr{T}, dst), src, N)
@@ -51,13 +51,13 @@ function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPointer{T}, src::Ptr{T}, N:
5151
end
5252

5353
# GPU -> CPU
54-
function Base.unsafe_copyto!(dev::MTLDevice, dst::Ptr{T}, src::MtlPointer{T}, N::Integer;
54+
function Base.unsafe_copyto!(dev::MTLDevice, dst::Ptr{T}, src::MtlPtr{T}, N::Integer;
5555
queue::MTLCommandQueue=global_queue(dev), async::Bool=false) where T
5656
storage_type = src.buffer.storageMode
5757
if storage_type == MTL.MTLStorageModePrivate
5858
# stage through a shared buffer
59-
shared = alloc(dev, N*sizeof(T), storage=Shared)
60-
unsafe_copyto!(dev, MtlPointer{T}(shared, 0), MtlPointer{T}(src.buffer, src.offset), N; queue, async=false)
59+
shared = alloc(dev, N*sizeof(T); storage=Shared)
60+
unsafe_copyto!(dev, MtlPtr{T}(shared, 0), MtlPtr{T}(src.buffer, src.offset), N; queue, async=false)
6161
unsafe_copyto!(dst, convert(Ptr{T}, shared), N)
6262
free(shared)
6363
elseif storage_type == MTL.MTLStorageModeShared
@@ -69,8 +69,8 @@ function Base.unsafe_copyto!(dev::MTLDevice, dst::Ptr{T}, src::MtlPointer{T}, N:
6969
end
7070

7171
# GPU -> GPU
72-
@autoreleasepool function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPointer{T},
73-
src::MtlPointer{T}, N::Integer;
72+
@autoreleasepool function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T},
73+
src::MtlPtr{T}, N::Integer;
7474
queue::MTLCommandQueue=global_queue(dev),
7575
async::Bool=false) where T
7676
cmdbuf = MTLCommandBuffer(queue)
@@ -81,7 +81,7 @@ end
8181
async || wait_completed(cmdbuf)
8282
end
8383

84-
@autoreleasepool function unsafe_fill!(dev::MTLDevice, ptr::MtlPointer{T},
84+
@autoreleasepool function unsafe_fill!(dev::MTLDevice, ptr::MtlPtr{T},
8585
value::Union{UInt8,Int8}, N::Integer) where T
8686
cmdbuf = MTLCommandBuffer(global_queue(dev))
8787
MTLBlitCommandEncoder(cmdbuf) do enc

src/pool.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ The storage kwarg controls where the buffer is stored. Possible values are:
5151
Note that `Private` buffers can't be directly accessed from the CPU, therefore you cannot
5252
use this option if you pass a ptr to initialize the memory.
5353
"""
54-
function alloc(dev::Union{MTLDevice,MTLHeap}, sz::Integer, args...; storage, kwargs...)
54+
function alloc(dev::Union{MTLDevice,MTLHeap}, sz::Integer, args...; kwargs...)
5555
@signpost_event log=log_array() "Allocate" "Size=$(Base.format_bytes(sz))"
5656

5757
time = Base.@elapsed begin
58-
buf = @autoreleasepool MTLBuffer(dev, sz, args...; storage, kwargs...)
58+
buf = @autoreleasepool MTLBuffer(dev, sz, args...; kwargs...)
5959
end
6060

6161
Base.@atomic alloc_stats.alloc_count + 1

test/array.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ check_storagemode(arr, smode) = Metal.storagemode(arr) == smode
7272
N = length(dim)
7373

7474
# mtl
75-
let arr = mtl(rand(2,2), storage= SM)
75+
let arr = mtl(rand(2,2); storage= SM)
7676
@test check_storagemode(arr, SM)
7777
end
7878

@@ -134,7 +134,7 @@ check_storagemode(arr, smode) = Metal.storagemode(arr) == smode
134134
end
135135

136136
let b = rand(Float32, 10)
137-
arr_mtl = mtl(b, storage=Private)
137+
arr_mtl = mtl(b; storage=Private)
138138
@test_throws ErrorException arr_mtl[1]
139139
@test Metal.@allowscalar arr_mtl[1] == b[1]
140140
end
@@ -145,7 +145,7 @@ check_storagemode(arr, smode) = Metal.storagemode(arr) == smode
145145
end
146146

147147
let b = rand(Float32, 10)
148-
arr_mtl = mtl(b, storage=Shared)
148+
arr_mtl = mtl(b; storage=Shared)
149149
@test arr_mtl[1] == b[1]
150150
end
151151
elseif SM == Metal.Managed
@@ -155,7 +155,7 @@ check_storagemode(arr, smode) = Metal.storagemode(arr) == smode
155155
end
156156

157157
let b = rand(Float32, 10)
158-
arr_mtl = mtl(b, storage=Managed)
158+
arr_mtl = mtl(b; storage=Managed)
159159
@test arr_mtl[1] == b[1]
160160
end
161161
end

test/metal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ end
454454
arr = Metal.zeros(T, 4)
455455

456456
buf = Base.unsafe_convert(MTL.MTLBuffer, arr)
457-
Metal.unsafe_fill!(current_device(), Metal.MtlPointer{T}(buf, 0), T(val), 4)
457+
Metal.unsafe_fill!(current_device(), Metal.MtlPtr{T}(buf, 0), T(val), 4)
458458

459459
@test all(Array(arr) .== val)
460460
end

0 commit comments

Comments
 (0)