Skip to content

Commit 3cf2f19

Browse files
vkuzopytorchmergebot
authored andcommitted
add copy_ support for float4 dtype (pytorch#169595)
Summary: Enables `copy_` support for the `torch.float4_e2m1fn_x2` dtype. This is useful when slicing a tensor across dim1 and then calling contiguous, which can happen in vllm and therefore should be supported. Test Plan: ``` pytest test/quantization/core/experimental/test_floatx.py -s -k test_float4_e2m1fn_x2 ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: pytorch#169595 Approved by: https://github.com/drisspg ghstack dependencies: pytorch#169575
1 parent ae64a53 commit 3cf2f19

File tree

3 files changed

+9
-0
lines changed

3 files changed

+9
-0
lines changed

aten/src/ATen/native/cpu/CopyKernel.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ void direct_copy_kernel(TensorIteratorBase &iter) {
235235
});
236236
} else if (dtype == ScalarType::ComplexHalf) {
237237
cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
238+
} else if (dtype == ScalarType::Float4_e2m1fn_x2) {
239+
cpu_kernel(iter, [=](Float4_e2m1fn_x2 a) -> Float4_e2m1fn_x2 { return a; });
238240
} else if (isBitsType(dtype)) {
239241
AT_DISPATCH_BIT_TYPES(dtype, "copy_kernel", [&] {
240242
cpu_kernel(

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
234234
AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
235235
gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
236236
});
237+
} else if (dtype == ScalarType::Float4_e2m1fn_x2) {
238+
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
239+
"Float4_e2m1fn_x2 to different types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
240+
gpu_kernel_nocast(iter, [] GPU_LAMBDA(Float4_e2m1fn_x2 x) { return x; });
237241
} else {
238242
AT_DISPATCH_V2(
239243
dtype, "copy_", AT_WRAP([&] {

test/quantization/core/experimental/test_floatx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,9 @@ def test_float4_e2m1fn_x2(self, device):
412412
x3 = copy.deepcopy(x1)
413413
self.assertEqual(x1, x3, atol=0, rtol=0)
414414

415+
# can call contiguous on a dim1 slice (calls `copy_` under the hood)
416+
x1[:, 0:2048].contiguous()
417+
415418
def test_f4_save_load(self, device):
416419
x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view(
417420
torch.float4_e2m1fn_x2

0 commit comments

Comments
 (0)