Skip to content

Commit ae64a53

Browse files
vkuzo authored and pytorchmergebot committed
make the float4 dtype support equality comparisons (pytorch#169575)
Summary: Makes `torch.allclose(a, b, atol=0, rtol=0)` work for `a` and `b` with dtype `torch.float4_e2m1fn_x2`. This is useful for testing. Test Plan: ``` pytest test/quantization/core/experimental/test_floatx.py -s -k test_float4_e2m1fn_x2 ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: pytorch#169575 Approved by: https://github.com/eqy, https://github.com/drisspg
1 parent 82e30f3 commit ae64a53

File tree

4 files changed

+29
-9
lines changed

4 files changed

+29
-9
lines changed

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -624,38 +624,38 @@ void ge_kernel(TensorIteratorBase& iter) {
624624
void eq_kernel(TensorIteratorBase& iter) {
625625
// See Note [special-case bool outputs]
626626
if (iter.dtype() == ScalarType::Bool) {
627-
_AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
627+
AT_DISPATCH_V2(iter.common_dtype(), "eq_cpu", AT_WRAP([&]() {
628628
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a == b; });
629-
});
629+
}), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2);
630630
} else {
631-
_AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
631+
AT_DISPATCH_V2(iter.common_dtype(), "eq_cpu", AT_WRAP([&]() {
632632
cpu_kernel_vec(
633633
iter,
634634
[](scalar_t a, scalar_t b) -> scalar_t {
635635
return static_cast<scalar_t>(a == b);
636636
},
637637
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
638638
-> Vectorized<scalar_t> { return a.eq(b); });
639-
});
639+
}), kComplexHalf, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2);
640640
}
641641
}
642642

643643
void ne_kernel(TensorIteratorBase& iter) {
644644
// See Note [special-case bool outputs]
645645
if (iter.dtype() == ScalarType::Bool) {
646-
_AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
646+
AT_DISPATCH_V2(iter.common_dtype(), "ne_cpu", AT_WRAP([&]() {
647647
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a != b; });
648-
});
648+
}), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2);
649649
} else {
650-
_AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
650+
AT_DISPATCH_V2(iter.common_dtype(), "ne_cpu", AT_WRAP([&]() {
651651
cpu_kernel_vec(
652652
iter,
653653
[](scalar_t a, scalar_t b) -> scalar_t {
654654
return static_cast<scalar_t>(a != b);
655655
},
656656
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
657657
-> Vectorized<scalar_t> { return a.ne(b); });
658-
});
658+
}), kComplexHalf, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2);
659659
}
660660
}
661661

aten/src/ATen/native/cuda/CompareEQKernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) {
3333
AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_cuda", AT_WRAP([&]() {
3434
opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
3535
iter, CompareEqFunctor<scalar_t>(op));
36-
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
36+
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kFloat4_e2m1fn_x2);
3737
}
3838

3939
void eq_kernel_cuda(TensorIteratorBase& iter) {

test/quantization/core/experimental/test_floatx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Owner(s): ["oncall: quantization"]
22

3+
import copy
34
import struct
45
import unittest
56

@@ -407,6 +408,10 @@ def test_float4_e2m1fn_x2(self, device):
407408
# can view uint8 as float4_e2m1fn_x2
408409
x2.view(torch.float4_e2m1fn_x2)
409410

411+
# can do equality comparisons
412+
x3 = copy.deepcopy(x1)
413+
self.assertEqual(x1, x3, atol=0, rtol=0)
414+
410415
def test_f4_save_load(self, device):
411416
x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view(
412417
torch.float4_e2m1fn_x2

torch/headeronly/util/Float4_e2m1fn_x2.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,23 @@ struct alignas(1) Float4_e2m1fn_x2 {
2525
C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
2626
};
2727

28+
/// Comparison operators
29+
inline C10_HOST_DEVICE bool operator==(
30+
const Float4_e2m1fn_x2& a,
31+
const Float4_e2m1fn_x2& b) {
32+
return a.val_ == b.val_;
33+
}
34+
35+
inline C10_HOST_DEVICE bool operator!=(
36+
const Float4_e2m1fn_x2& a,
37+
const Float4_e2m1fn_x2& b) {
38+
return a.val_ != b.val_;
39+
}
40+
2841
} // namespace c10
2942

3043
HIDDEN_NAMESPACE_BEGIN(torch, headeronly)
3144
using c10::Float4_e2m1fn_x2;
45+
using c10::operator==;
46+
using c10::operator!=;
3247
HIDDEN_NAMESPACE_END(torch, headeronly)

0 commit comments

Comments (0)