Skip to content

Commit e5f62b3

Browse files
committed
Rename MtlPointer to MtlPtr
For consistency with `Ptr` and `CuPtr`
1 parent 23c3f4f commit e5f62b3

File tree

4 files changed

+23
-23
lines changed

4 files changed

+23
-23
lines changed

src/array.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ Base.sizeof(x::MtlArray) = Base.elsize(x) * length(x)
169169

170170
@inline function Base.pointer(x::MtlArray{T}, i::Integer=1; storage=Private) where {T}
171171
PT = if storage == Private
172-
MtlPointer{T}
172+
MtlPtr{T}
173173
elseif storage == Shared || storage == Managed
174174
Ptr{T}
175175
else
@@ -179,9 +179,9 @@ Base.sizeof(x::MtlArray) = Base.elsize(x) * length(x)
179179
end
180180

181181

182-
function Base.unsafe_convert(::Type{MtlPointer{T}}, x::MtlArray) where {T}
182+
function Base.unsafe_convert(::Type{MtlPtr{T}}, x::MtlArray) where {T}
183183
buf = x.data[]
184-
MtlPointer{T}(buf, x.offset*Base.elsize(x))
184+
MtlPtr{T}(buf, x.offset*Base.elsize(x))
185185
end
186186

187187
function Base.unsafe_convert(::Type{Ptr{S}}, x::MtlArray{T}) where {S, T}
@@ -487,7 +487,7 @@ function Base.unsafe_wrap(t::Type{<:Array{T}}, buf::MTLBuffer, dims; own=false)
487487
return unsafe_wrap(t, ptr, dims; own)
488488
end
489489

490-
function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPointer{T}, dims; own=false) where T
490+
function Base.unsafe_wrap(t::Type{<:Array{T}}, ptr::MtlPtr{T}, dims; own=false) where T
491491
return unsafe_wrap(t, convert(Ptr{T}, ptr), dims; own)
492492
end
493493

@@ -513,7 +513,7 @@ function Base.resize!(A::MtlVector{T}, n::Integer) where T
513513
# replace the data with a new one. this 'unshares' the array.
514514
# as a result, we can safely support resizing unowned buffers.
515515
buf = alloc(device(A), bufsize; storage=storagemode(A))
516-
ptr = MtlPointer{T}(buf)
516+
ptr = MtlPtr{T}(buf)
517517
m = min(length(A), n)
518518
if m > 0
519519
unsafe_copyto!(device(A), ptr, pointer(A), m)

src/compiler/execution.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ function Adapt.adapt_storage(to::Adaptor, buf::MTLBuffer)
112112
end
113113
reinterpret(Core.LLVMPtr{Nothing,AS.Device}, buf.gpuAddress)
114114
end
115-
function Adapt.adapt_storage(to::Adaptor, ptr::MtlPointer{T}) where {T}
115+
function Adapt.adapt_storage(to::Adaptor, ptr::MtlPtr{T}) where {T}
116116
reinterpret(Core.LLVMPtr{T,AS.Device}, adapt(to, ptr.buffer)) + ptr.offset
117117
end
118118

@@ -209,7 +209,7 @@ const _kernel_instances = Dict{UInt, Any}()
209209
end
210210

211211
# the arguments passed into this function have not been `mtlconvert`ed, because we need
212-
# to retain the top-level MTLBuffer and MtlPointer objects. eager conversion of nested
212+
# to retain the top-level MTLBuffer and MtlPtr objects. eager conversion of nested
213213
# such objects to LLVMPtr seems fine, somehow.
214214
# TODO: can we just convert everything eagerly and support top-level LLVMPtrs?
215215

@@ -219,7 +219,7 @@ const _kernel_instances = Dict{UInt, Any}()
219219
if argtyp <: MTLBuffer
220220
# top-level buffers are passed as a pointer-valued argument
221221
push!(ex.args, :(set_buffer!(cce, $argex, 0, $idx)))
222-
elseif argtyp <: MtlPointer
222+
elseif argtyp <: MtlPtr
223223
# the same as a buffer, but with an offset
224224
push!(ex.args, :(set_buffer!(cce, $argex.buffer, $argex.offset, $idx)))
225225
elseif isghosttype(argtyp) || Core.Compiler.isconstType(argtyp)

src/memory.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,30 @@
77
# we cannot take a MTLBuffer's handle and work with that as it were a pointer to memory.
88
# instead, the Metal APIs always take the original handle and an offset parameter.
99

10-
struct MtlPointer{T}
10+
struct MtlPtr{T}
1111
buffer::MTLBuffer
1212
offset::UInt # in bytes
1313

14-
function MtlPointer{T}(buffer::MTLBuffer, offset=0) where {T}
14+
function MtlPtr{T}(buffer::MTLBuffer, offset=0) where {T}
1515
new(buffer, offset)
1616
end
1717
end
1818

19-
Base.eltype(::Type{<:MtlPointer{T}}) where {T} = T
19+
Base.eltype(::Type{<:MtlPtr{T}}) where {T} = T
2020

2121
# limited arithmetic
22-
Base.:(+)(x::MtlPointer{T}, y::Integer) where {T} = MtlPointer{T}(x.buffer, x.offset+y)
23-
Base.:(-)(x::MtlPointer{T}, y::Integer) where {T} = MtlPointer{T}(x.buffer, x.offset-y)
24-
Base.:(+)(x::Integer, y::MtlPointer{T}) where {T} = MtlPointer{T}(x.buffer, y+x.offset)
22+
Base.:(+)(x::MtlPtr{T}, y::Integer) where {T} = MtlPtr{T}(x.buffer, x.offset+y)
23+
Base.:(-)(x::MtlPtr{T}, y::Integer) where {T} = MtlPtr{T}(x.buffer, x.offset-y)
24+
Base.:(+)(x::Integer, y::MtlPtr{T}) where {T} = MtlPtr{T}(x.buffer, y+x.offset)
2525

26-
Base.convert(::Type{Ptr{T}}, ptr::MtlPointer) where {T} =
26+
Base.convert(::Type{Ptr{T}}, ptr::MtlPtr) where {T} =
2727
convert(Ptr{T}, ptr.buffer) + ptr.offset
2828

2929

3030
## operations
3131

3232
# CPU -> GPU
33-
function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPointer{T}, src::Ptr{T}, N::Integer;
33+
function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T}, src::Ptr{T}, N::Integer;
3434
queue::MTLCommandQueue=global_queue(dev), async::Bool=false) where T
3535
storage_type = dst.buffer.storageMode
3636
if storage_type == MTL.MTLStorageModePrivate
@@ -39,7 +39,7 @@ function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPointer{T}, src::Ptr{T}, N:
3939
# unsafe_copyto!(dev, dst, pointer(shared), N; queue, async=false)
4040
# free(shared)
4141
tmp_buf = alloc(dev, N*sizeof(T), src; storage=Shared) #CPU -> GPU (Shared)
42-
unsafe_copyto!(dev, MtlPointer{T}(dst.buffer, dst.offset), MtlPointer{T}(tmp_buf, 0), N; queue, async=false) # GPU (Shared) -> GPU (Private)
42+
unsafe_copyto!(dev, MtlPtr{T}(dst.buffer, dst.offset), MtlPtr{T}(tmp_buf, 0), N; queue, async=false) # GPU (Shared) -> GPU (Private)
4343
free(tmp_buf)
4444
elseif storage_type == MTL.MTLStorageModeShared
4545
unsafe_copyto!(convert(Ptr{T}, dst), src, N)
@@ -51,13 +51,13 @@ function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPointer{T}, src::Ptr{T}, N:
5151
end
5252

5353
# GPU -> CPU
54-
function Base.unsafe_copyto!(dev::MTLDevice, dst::Ptr{T}, src::MtlPointer{T}, N::Integer;
54+
function Base.unsafe_copyto!(dev::MTLDevice, dst::Ptr{T}, src::MtlPtr{T}, N::Integer;
5555
queue::MTLCommandQueue=global_queue(dev), async::Bool=false) where T
5656
storage_type = src.buffer.storageMode
5757
if storage_type == MTL.MTLStorageModePrivate
5858
# stage through a shared buffer
5959
shared = alloc(dev, N*sizeof(T); storage=Shared)
60-
unsafe_copyto!(dev, MtlPointer{T}(shared, 0), MtlPointer{T}(src.buffer, src.offset), N; queue, async=false)
60+
unsafe_copyto!(dev, MtlPtr{T}(shared, 0), MtlPtr{T}(src.buffer, src.offset), N; queue, async=false)
6161
unsafe_copyto!(dst, convert(Ptr{T}, shared), N)
6262
free(shared)
6363
elseif storage_type == MTL.MTLStorageModeShared
@@ -69,8 +69,8 @@ function Base.unsafe_copyto!(dev::MTLDevice, dst::Ptr{T}, src::MtlPointer{T}, N:
6969
end
7070

7171
# GPU -> GPU
72-
@autoreleasepool function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPointer{T},
73-
src::MtlPointer{T}, N::Integer;
72+
@autoreleasepool function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T},
73+
src::MtlPtr{T}, N::Integer;
7474
queue::MTLCommandQueue=global_queue(dev),
7575
async::Bool=false) where T
7676
cmdbuf = MTLCommandBuffer(queue)
@@ -81,7 +81,7 @@ end
8181
async || wait_completed(cmdbuf)
8282
end
8383

84-
@autoreleasepool function unsafe_fill!(dev::MTLDevice, ptr::MtlPointer{T},
84+
@autoreleasepool function unsafe_fill!(dev::MTLDevice, ptr::MtlPtr{T},
8585
value::Union{UInt8,Int8}, N::Integer) where T
8686
cmdbuf = MTLCommandBuffer(global_queue(dev))
8787
MTLBlitCommandEncoder(cmdbuf) do enc

test/metal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ end
454454
arr = Metal.zeros(T, 4)
455455

456456
buf = Base.unsafe_convert(MTL.MTLBuffer, arr)
457-
Metal.unsafe_fill!(current_device(), Metal.MtlPointer{T}(buf, 0), T(val), 4)
457+
Metal.unsafe_fill!(current_device(), Metal.MtlPtr{T}(buf, 0), T(val), 4)
458458

459459
@test all(Array(arr) .== val)
460460
end

0 commit comments

Comments
 (0)