Skip to content

Commit 28c42cc

Browse files
msaroufim authored and pytorchmergebot committed
compile_kernel: Add DLPack test (pytorch#163166)
Note to self: I should probably start using ghstack. This is rebased on top of pytorch#163165, so you only need to review this commit: pytorch@7387c1b. This test doesn't add any new functionality; it just ensures DLPack conversion is working well. Pull Request resolved: pytorch#163166. Approved by: https://github.com/janeyx99, https://github.com/albanD
1 parent 0661ecd commit 28c42cc

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

test/test_cuda.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7265,6 +7265,42 @@ def test_compile_kernel_template(self):
72657265
expected = a + b
72667266
self.assertEqual(c, expected)
72677267

7268+
@unittest.skipIf(not TEST_CUDA, "No CUDA")
def test_compile_kernel_dlpack(self):
    """Verify that _compile_kernel accepts tensors round-tripped through DLPack.

    A kernel compiled via torch.cuda._compile_kernel is launched with inputs
    that were exported to DLPack capsules and re-imported; the output must
    match plain tensor addition, and the re-imported tensors must alias the
    originals' memory (writing through one is visible in the other).
    """
    kernel_source = """
    __global__ void add_tensors(const float* a, const float* b, float* c, int n) {
        int i = threadIdx.x + blockIdx.x * blockDim.x;
        if (i < n)
            c[i] = a[i] + b[i];
    }
    """

    from torch.cuda import _compile_kernel

    add_kernel = _compile_kernel(kernel_source, "add_tensors")

    num_elems = 512
    a = torch.rand(num_elems, device="cuda", dtype=torch.float32)
    b = torch.rand(num_elems, device="cuda", dtype=torch.float32)

    # Round-trip each input through a DLPack capsule; the result should be
    # a view over the same device memory, not a copy.
    to_capsule = torch.utils.dlpack.to_dlpack
    from_capsule = torch.utils.dlpack.from_dlpack
    a_dlpack = from_capsule(to_capsule(a))
    b_dlpack = from_capsule(to_capsule(b))
    c = torch.empty_like(a)

    tpb = 256  # threads per block
    grid_x = (num_elems + tpb - 1) // tpb  # ceil-div so every element is covered

    add_kernel(
        grid=(grid_x, 1, 1),
        block=(tpb, 1, 1),
        args=[a_dlpack, b_dlpack, c, num_elems],
    )

    self.assertEqual(c, a + b)
    # Mutating the DLPack-imported tensor must be visible through the
    # original tensor — proof they share storage rather than copying.
    a_dlpack[0] = 42.0
    self.assertEqual(a[0].item(), 42.0, "DLPack tensors should share memory")
7303+
72687304

72697305
@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
72707306
class TestCudaDeviceParametrized(TestCase):

0 commit comments

Comments
 (0)