Open
Description
Describe the bug
In the benchmark below, CuPy arrays are slower than both PyTorch tensors and DeviceNDArray.
Steps/Code to reproduce bug
pixi run -e cu-12-9-py312 bench
Expected behavior
I would expect CuPy to at least match PyTorch performance here.
Details
I dug into what's happening, and the relevant function is this one:

def from_cuda_array_interface(desc, owner=None, sync=True):
    """Create a DeviceNDArray from a cuda-array-interface description.
    The ``owner`` is the owner of the underlying memory.
    The resulting DeviceNDArray will acquire a reference from it.

    If ``sync`` is ``True``, then the imported stream (if present) will be
    synchronized.
    """
    version = desc.get("version")

    # Mask introduced in version 1
    if 1 <= version:
        mask = desc.get("mask")
        # Would ideally be better to detect if the mask is all valid
        if mask is not None:
            raise NotImplementedError("Masked arrays are not supported")

    shape = desc["shape"]
    strides = desc.get("strides")

    shape, strides, dtype = prepare_shape_strides_dtype(
        shape, strides, desc["typestr"], order="C"
    )
    size = driver.memory_size_from_info(shape, strides, dtype.itemsize)

    cudevptr_class = driver.binding.CUdeviceptr
    devptr = cudevptr_class(desc["data"][0])
    data = driver.MemoryPointer(devptr, size=size, owner=owner)
    stream_ptr = desc.get("stream", None)
    if stream_ptr is not None:
        stream = external_stream(stream_ptr)
        if sync and config.CUDA_ARRAY_INTERFACE_SYNC:
            stream.synchronize()
    else:
        stream = 0  # No "Numba default stream", not the CUDA default stream
    da = devicearray.DeviceNDArray(
        shape=shape, strides=strides, dtype=dtype, gpu_data=data, stream=stream
    )
    return da

In particular, this part:

    if stream_ptr is not None:
        stream = external_stream(stream_ptr)
        if sync and config.CUDA_ARRAY_INTERFACE_SYNC:
            stream.synchronize()

This branch is hit, so both external_stream() and stream.synchronize() are called, and each of them is relatively expensive.
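
To make the control flow easier to see without a GPU, here is a minimal sketch that models only the stream-handling branch. Everything in it is a stand-in, not the real Numba code: `FakeStream`, the stubbed `external_stream`, `import_stream`, `trace`, and the `interface_sync` parameter (standing in for `config.CUDA_ARRAY_INTERFACE_SYNC`) are all hypothetical names. It records which calls would happen for a description with and without a `stream` entry.

```python
# GPU-free model of the stream branch in from_cuda_array_interface.
# All names below are hypothetical stand-ins, not Numba APIs.
calls = []


class FakeStream:
    """Stand-in for the stream object returned by external_stream()."""

    def synchronize(self):
        calls.append("synchronize")


def external_stream(ptr):
    """Stub that only records that a stream wrapper would be created."""
    calls.append("external_stream")
    return FakeStream()


def import_stream(desc, sync=True, interface_sync=True):
    """Mirror the branch quoted above; interface_sync models the config flag."""
    stream_ptr = desc.get("stream", None)
    if stream_ptr is not None:
        stream = external_stream(stream_ptr)
        if sync and interface_sync:
            stream.synchronize()
    else:
        stream = 0
    return stream


def trace(desc, **kwargs):
    """Return the list of calls made while importing one description."""
    calls.clear()
    import_stream(desc, **kwargs)
    return list(calls)


no_stream = trace({"version": 3})
with_stream = trace({"version": 3, "stream": 1})
no_sync = trace({"version": 3, "stream": 1}, sync=False)

print(no_stream)    # no calls when the description has no stream entry
print(with_stream)  # both calls when a stream is present and sync is on
print(no_sync)      # only external_stream when sync is off
```

In this model, passing `sync=False` skips only the synchronize step; the stream wrapper is still created whenever the description carries a stream, so that part of the cost would remain.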