Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/cudadrv/CUDAdrv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ using Printf

using LazyArtifacts

# Julia has several notions of `sizeof`:
# - `Base.sizeof` is the size of an object in memory
# - `Base.aligned_sizeof` is the size of an object when stored in an array or allocated inline
# Both are equivalent for immutable objects, but differ for mutable singletons and `Symbol`.
# We use `aligned_sizeof` since we care about the size of a type in an array.
import Base: aligned_sizeof

# low-level wrappers
include("libcuda.jl")
Expand Down
32 changes: 16 additions & 16 deletions lib/cudadrv/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ for (fn, srcPtrTy, dstPtrTy) in (("cuMemcpyDtoHAsync_v2", :CuPtr, :Ptr),
@eval function Base.unsafe_copyto!(dst::$dstPtrTy{T}, src::$srcPtrTy{T}, N::Integer;
stream::CuStream=stream(),
async::Bool=false) where T
$(getproperty(CUDA, Symbol(fn)))(dst, src, N*sizeof(T), stream)
$(getproperty(CUDA, Symbol(fn)))(dst, src, N*aligned_sizeof(T), stream)
async || synchronize(stream)
return dst
end
Expand All @@ -423,11 +423,11 @@ function Base.unsafe_copyto!(dst::CuPtr{T}, src::CuPtr{T}, N::Integer;
dst_dev = device(dst)
src_dev = device(src)
if dst_dev == src_dev
cuMemcpyDtoDAsync_v2(dst, src, N*sizeof(T), stream)
cuMemcpyDtoDAsync_v2(dst, src, N*aligned_sizeof(T), stream)
else
cuMemcpyPeerAsync(dst, context(dst_dev),
src, context(src_dev),
N*sizeof(T), stream)
N*aligned_sizeof(T), stream)
end
async || synchronize(stream)
return dst
Expand All @@ -436,24 +436,24 @@ end
function Base.unsafe_copyto!(dst::CuArrayPtr{T}, doffs::Integer, src::Ptr{T}, N::Integer;
stream::CuStream=stream(),
async::Bool=false) where T
cuMemcpyHtoAAsync_v2(dst, doffs, src, N*sizeof(T), stream)
cuMemcpyHtoAAsync_v2(dst, doffs, src, N*aligned_sizeof(T), stream)
async || synchronize(stream)
return dst
end

function Base.unsafe_copyto!(dst::Ptr{T}, src::CuArrayPtr{T}, soffs::Integer, N::Integer;
stream::CuStream=stream(),
async::Bool=false) where T
cuMemcpyAtoHAsync_v2(dst, src, soffs, N*sizeof(T), stream)
cuMemcpyAtoHAsync_v2(dst, src, soffs, N*aligned_sizeof(T), stream)
async || synchronize(stream)
return dst
end

Base.unsafe_copyto!(dst::CuArrayPtr{T}, doffs::Integer, src::CuPtr{T}, N::Integer) where {T} =
cuMemcpyDtoA_v2(dst, doffs, src, N*sizeof(T))
cuMemcpyDtoA_v2(dst, doffs, src, N*aligned_sizeof(T))

Base.unsafe_copyto!(dst::CuPtr{T}, src::CuArrayPtr{T}, soffs::Integer, N::Integer) where {T} =
cuMemcpyAtoD_v2(dst, src, soffs, N*sizeof(T))
cuMemcpyAtoD_v2(dst, src, soffs, N*aligned_sizeof(T))

# Offset-less convenience method: forwards to the offset-taking `unsafe_copyto!`
# methods with a destination offset of 0, passing any keyword arguments through
# unchanged.
function Base.unsafe_copyto!(dst::CuArrayPtr, src, N::Integer; kwargs...)
    return Base.unsafe_copyto!(dst, 0, src, N; kwargs...)
end
Expand Down Expand Up @@ -529,15 +529,15 @@ function unsafe_copy2d!(dst::Union{Ptr{T},CuPtr{T},CuArrayPtr{T}}, dstTyp::Type{

params_ref = Ref(CUDA_MEMCPY2D(
# source
(srcPos.x-1)*sizeof(T), srcPos.y-1,
(srcPos.x-1)*aligned_sizeof(T), srcPos.y-1,
srcMemoryType, srcHost, srcDevice, srcArray,
srcPitch,
# destination
(dstPos.x-1)*sizeof(T), dstPos.y-1,
(dstPos.x-1)*aligned_sizeof(T), dstPos.y-1,
dstMemoryType, dstHost, dstDevice, dstArray,
dstPitch,
# extent
width*sizeof(T), height
width*aligned_sizeof(T), height
))
cuMemcpy2DAsync_v2(params_ref, stream)
async || synchronize(stream)
Expand Down Expand Up @@ -569,8 +569,8 @@ function unsafe_copy3d!(dst::Union{Ptr{T},CuPtr{T},CuArrayPtr{T}}, dstTyp::Type{
# when using the stream-ordered memory allocator
# NOTE: we apply the workaround unconditionally, since we want to keep this call cheap.
if v"11.2" <= driver_version() <= v"11.3" #&& pools[device()].stream_ordered
srcOffset = (srcPos.x-1)*sizeof(T) + srcPitch*((srcPos.y-1) + srcHeight*(srcPos.z-1))
dstOffset = (dstPos.x-1)*sizeof(T) + dstPitch*((dstPos.y-1) + dstHeight*(dstPos.z-1))
srcOffset = (srcPos.x-1)*aligned_sizeof(T) + srcPitch*((srcPos.y-1) + srcHeight*(srcPos.z-1))
dstOffset = (dstPos.x-1)*aligned_sizeof(T) + dstPitch*((dstPos.y-1) + dstHeight*(dstPos.z-1))
else
srcOffset = 0
dstOffset = 0
Expand Down Expand Up @@ -622,23 +622,23 @@ function unsafe_copy3d!(dst::Union{Ptr{T},CuPtr{T},CuArrayPtr{T}}, dstTyp::Type{

params_ref = Ref(CUDA_MEMCPY3D(
# source
srcOffset==0 ? (srcPos.x-1)*sizeof(T) : 0,
srcOffset==0 ? (srcPos.x-1)*aligned_sizeof(T) : 0,
srcOffset==0 ? srcPos.y-1 : 0,
srcOffset==0 ? srcPos.z-1 : 0,
0, # LOD
srcMemoryType, srcHost, srcDevice, srcArray,
C_NULL, # reserved
srcPitch, srcHeight,
# destination
dstOffset==0 ? (dstPos.x-1)*sizeof(T) : 0,
dstOffset==0 ? (dstPos.x-1)*aligned_sizeof(T) : 0,
dstOffset==0 ? dstPos.y-1 : 0,
dstOffset==0 ? dstPos.z-1 : 0,
0, # LOD
dstMemoryType, dstHost, dstDevice, dstArray,
C_NULL, # reserved
dstPitch, dstHeight,
# extent
width*sizeof(T), height, depth
width*aligned_sizeof(T), height, depth
))
cuMemcpy3DAsync_v2(params_ref, stream)
async || synchronize(stream)
Expand Down Expand Up @@ -698,7 +698,7 @@ function pin(ref::Base.RefValue{T}) where T
ctx = context()
ptr = Base.unsafe_convert(Ptr{T}, ref)

__pin(ptr, sizeof(T))
__pin(ptr, aligned_sizeof(T))
finalizer(ref) do _
__unpin(ptr, ctx)
end
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/module/global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct CuGlobal{T}
ptr_ref = Ref{CuPtr{Cvoid}}()
nbytes_ref = Ref{Csize_t}()
cuModuleGetGlobal_v2(ptr_ref, nbytes_ref, mod, name)
if nbytes_ref[] != sizeof(T)
if nbytes_ref[] != aligned_sizeof(T)
throw(ArgumentError("size of global '$name' does not match type parameter type $T"))
end
buf = DeviceMemory(device(), context(), ptr_ref[], nbytes_ref[], false)
Expand Down
12 changes: 12 additions & 0 deletions src/CUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ import NVTX

using Printf

# Julia has several notions of `sizeof`:
# - `Base.sizeof` is the size of an object in memory
# - `Base.aligned_sizeof` is the size of an object when stored in an array or allocated inline
# Both are equivalent for immutable objects, but differ for mutable singletons and `Symbol`.
# We use `aligned_sizeof` since we care about the size of a type in an array.
# Bring `aligned_sizeof` into this module's namespace.
# On Julia 1.11 and later the Base function is imported directly; on older
# versions a generated wrapper evaluates `Base.aligned_sizeof(T)` while the
# method body is being generated and returns the result as a literal.
@static if VERSION >= v"1.11.0"
    import Base: aligned_sizeof
else
    @generated function aligned_sizeof(::Type{T}) where {T}
        nbytes = Base.aligned_sizeof(T)
        return :($nbytes)
    end
end

## source code includes

Expand Down
16 changes: 8 additions & 8 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ mutable struct CuArray{T,N,M} <: AbstractGPUArray{T,N}

function CuArray{T,N,M}(::UndefInitializer, dims::Dims{N}) where {T,N,M}
check_eltype("CuArray", T)
maxsize = prod(dims) * sizeof(T)
maxsize = prod(dims) * aligned_sizeof(T)
bufsize = if Base.isbitsunion(T)
# type tag array past the data
maxsize + prod(dims)
Expand All @@ -84,7 +84,7 @@ mutable struct CuArray{T,N,M} <: AbstractGPUArray{T,N}
end

function CuArray{T,N}(data::DataRef{Managed{M}}, dims::Dims{N};
maxsize::Int=prod(dims) * sizeof(T), offset::Int=0) where {T,N,M}
maxsize::Int=prod(dims) * aligned_sizeof(T), offset::Int=0) where {T,N,M}
check_eltype("CuArray", T)
obj = new{T,N,M}(data, maxsize, offset, dims)
finalizer(unsafe_free!, obj)
Expand Down Expand Up @@ -235,7 +235,7 @@ function Base.unsafe_wrap(::Type{CuArray{T,N,M}},
ptr::CuPtr{T}, dims::NTuple{N,Int};
own::Bool=false, ctx::CuContext=context()) where {T,N,M}
isbitstype(T) || throw(ArgumentError("Can only unsafe_wrap a pointer to a bits type"))
sz = prod(dims) * sizeof(T)
sz = prod(dims) * aligned_sizeof(T)

# create a memory object
mem = if M == UnifiedMemory
Expand Down Expand Up @@ -290,7 +290,7 @@ supports_hmm(dev) = driver_version() >= v"12.2" &&
function Base.unsafe_wrap(::Type{CuArray{T,N,M}}, p::Ptr{T}, dims::NTuple{N,Int};
ctx::CuContext=context()) where {T,N,M<:AbstractMemory}
isbitstype(T) || throw(ArgumentError("Can only unsafe_wrap a pointer to a bits type"))
sz = prod(dims) * sizeof(T)
sz = prod(dims) * aligned_sizeof(T)

data = if M == UnifiedMemory
# HMM extends unified memory to include system memory
Expand Down Expand Up @@ -338,7 +338,7 @@ Base.unsafe_wrap(::Type{CuArray{T,N,M}}, a::Array{T,N}) where {T,N,M} =

## array interface

Base.elsize(::Type{<:CuArray{T}}) where {T} = sizeof(T)
Base.elsize(::Type{<:CuArray{T}}) where {T} = aligned_sizeof(T)

Base.size(x::CuArray) = x.dims
Base.sizeof(x::CuArray) = Base.elsize(x) * length(x)
Expand Down Expand Up @@ -837,7 +837,7 @@ end
## derived arrays

function GPUArrays.derive(::Type{T}, a::CuArray, dims::Dims{N}, offset::Int) where {T,N}
offset = (a.offset * Base.elsize(a)) ÷ sizeof(T) + offset
offset = (a.offset * Base.elsize(a)) ÷ aligned_sizeof(T) + offset
CuArray{T,N}(copy(a.data), dims; a.maxsize, offset)
end

Expand All @@ -851,7 +851,7 @@ function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{
end
function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{Base.RangeIndex,Base.ReshapedUnitRange}}}}) where {T,N,P}
return Base.unsafe_convert(CuPtr{T}, parent(V)) +
(Base.first_index(V)-1)*sizeof(T)
(Base.first_index(V)-1)*aligned_sizeof(T)
end


Expand All @@ -874,7 +874,7 @@ function Base.resize!(A::CuVector{T}, n::Integer) where T
n == length(A) && return A

# TODO: add additional space to allow for quicker resizing
maxsize = n * sizeof(T)
maxsize = n * aligned_sizeof(T)
bufsize = if isbitstype(T)
maxsize
else
Expand Down
9 changes: 2 additions & 7 deletions src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,6 @@ end
CompilerConfig(target, params; kernel, name, always_inline)
end

# a version of `sizeof` that returns the size of the argument we'll pass.
# for example, it supports Symbols where `sizeof(Symbol)` would fail.
argsize(x::Any) = sizeof(x)
argsize(::Type{Symbol}) = sizeof(Ptr{Cvoid})

# compile to executable machine code
function compile(@nospecialize(job::CompilerJob))
# lower to PTX
Expand Down Expand Up @@ -286,7 +281,7 @@ function compile(@nospecialize(job::CompilerJob))
argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt
!isghosttype(dt) && !Core.Compiler.isconstType(dt)
end
param_usage = sum(argsize, argtypes)
param_usage = sum(aligned_sizeof, argtypes)
param_limit = 4096
if cap >= v"7.0" && ptx >= v"8.1"
param_limit = 32764
Expand All @@ -310,7 +305,7 @@ function compile(@nospecialize(job::CompilerJob))
continue
end
name = source_argnames[i]
details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))"
details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(aligned_sizeof(typ)))"
end
details *= "\n"

Expand Down
8 changes: 4 additions & 4 deletions src/device/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct CuDeviceArray{T,N,A} <: DenseArray{T,N}

# inner constructors, fully parameterized, exact types (ie. Int not <:Integer)
CuDeviceArray{T,N,A}(ptr::LLVMPtr{T,A}, dims::Tuple,
maxsize::Int=prod(dims)*sizeof(T)) where {T,A,N} =
maxsize::Int=prod(dims)*aligned_sizeof(T)) where {T,A,N} =
new(ptr, maxsize, dims, prod(dims))
end

Expand All @@ -39,7 +39,7 @@ const CuDeviceMatrix = CuDeviceArray{T,2,A} where {T,A}

## array interface

Base.elsize(::Type{<:CuDeviceArray{T}}) where {T} = sizeof(T)
Base.elsize(::Type{<:CuDeviceArray{T}}) where {T} = aligned_sizeof(T)

Base.size(g::CuDeviceArray) = g.dims
Base.sizeof(x::CuDeviceArray) = Base.elsize(x) * length(x)
Expand Down Expand Up @@ -239,12 +239,12 @@ function Base.reinterpret(::Type{T}, a::CuDeviceArray{S,N,A}) where {T,S,N,A}
err = GPUArrays._reinterpret_exception(T, a)
err === nothing || throw(err)

if sizeof(T) == sizeof(S) # fast case
if aligned_sizeof(T) == aligned_sizeof(S) # fast case
return CuDeviceArray{T,N,A}(reinterpret(LLVMPtr{T,A}, a.ptr), size(a), a.maxsize)
end

isize = size(a)
size1 = div(isize[1]*sizeof(S), sizeof(T))
size1 = div(isize[1]*aligned_sizeof(S), aligned_sizeof(T))
osize = tuple(size1, Base.tail(isize)...)
return CuDeviceArray{T,N,A}(reinterpret(LLVMPtr{T,A}, a.ptr), osize, a.maxsize)
end
Expand Down
2 changes: 1 addition & 1 deletion src/device/texture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Base.convert(::Type{CUtexObject}, t::CuDeviceTexture) = t.handle

## array interface

Base.elsize(::Type{<:CuDeviceTexture{T}}) where {T} = sizeof(T)
Base.elsize(::Type{<:CuDeviceTexture{T}}) where {T} = aligned_sizeof(T)

Base.size(tm::CuDeviceTexture) = tm.dims
# Total size in bytes of the texture's elements: per-element size times the
# number of elements. The element size is looked up through the
# `Base.elsize(::Type{<:CuDeviceTexture{T}})` method defined above, and the
# argument is referenced by its actual name `tm`.
Base.sizeof(tm::CuDeviceTexture) = Base.elsize(typeof(tm)) * length(tm)
Expand Down
2 changes: 1 addition & 1 deletion src/refpointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ mutable struct CuRefValue{T} <: AbstractCuRef{T}

function CuRefValue{T}() where {T}
check_eltype("CuRef", T)
buf = pool_alloc(DeviceMemory, sizeof(T))
buf = pool_alloc(DeviceMemory, aligned_sizeof(T))
obj = new(buf)
finalizer(obj) do _
pool_free(buf)
Expand Down
3 changes: 2 additions & 1 deletion src/texture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ Base.size(tm::CuTextureArray) = tm.dims
Base.length(tm::CuTextureArray) = prod(size(tm))

Base.eltype(tm::CuTextureArray{T,N}) where {T,N} = T
Base.elsize(tm::CuTextureArray) = aligned_sizeof(eltype(tm))

Base.sizeof(tm::CuTextureArray) = sizeof(eltype(tm)) * length(tm)
Base.sizeof(tm::CuTextureArray) = Base.elsize(tm) * length(tm)

Base.pointer(t::CuTextureArray) = t.mem.ptr

Expand Down