Skip to content

Commit 889a4fe

Browse files
authored
Enhance DLPack compatibility (#7045)
1 parent 3835163 commit 889a4fe

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

libcudacxx/include/cuda/__tma/make_tma_descriptor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,9 @@ __get_tensor_sizes(const ::DLTensor& __tensor, int __rank, ::CUtensorMapDataType
414 414
int64_t __cumulative_size = 1;
415 415
if (__input_strides == nullptr)
416 416
{
417+
# if DLPACK_MAJOR_VERSION > 1 || (DLPACK_MAJOR_VERSION == 1 && DLPACK_MINOR_VERSION >= 2)
418+
_CCCL_THROW(::std::invalid_argument{"__tensor.strides=nullptr is not supported for DLPack v1.2 and later"});
419+
# else
417 420
for (int __i = 0; __i < __rank - 1; ++__i)
418 421
{
419 422
// TODO(fbusato): check mul overflow
@@ -430,6 +433,7 @@ __get_tensor_sizes(const ::DLTensor& __tensor, int __rank, ::CUtensorMapDataType
430 433
__output_strides[__i] = __stride_bytes;
431 434
}
432 435
return __output_strides;
436+
# endif // DLPACK_MAJOR_VERSION > 1 || (DLPACK_MAJOR_VERSION == 1 && DLPACK_MINOR_VERSION >= 2)
433 437
}
434 438
// TMA ignores the innermost stride (always 1).
435 439
for (int __i = __rank - 2; __i >= 0; --__i)

libcudacxx/test/libcudacxx/cuda/tma/make_tma_descriptor.pass.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,6 @@ bool test_strides()
130 130
// stride is 0
131 131
strides_storage[0] = 0;
132 132
unused(cuda::make_tma_descriptor(tensor, box_sizes));
133-
// stride is nullptr
134-
tensor.strides = nullptr;
135-
unused(cuda::make_tma_descriptor(tensor, box_sizes));
136 133
return true;
137 134
}
138 135

0 commit comments

Comments
 (0)