Skip to content

Commit b965d57

Browse files
committed
Improve trait bounds
1 parent 77ab323 commit b965d57

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ const LOG2_SIGNIFICAND: [f64; 16] = [
8181
/// [ops]: https://docs.rs/num-traits/latest/num_traits/trait.NumOps.html
8282
pub trait Minifloat: Copy + PartialEq + PartialOrd + Neg<Output = Self> {
8383
/// Storage type
84-
type Bits: PrimInt + Unsigned;
84+
type Bits: PrimInt + Unsigned + 'static;
8585

8686
/// Whether the type is signed
8787
const S: bool = true;

tests/all.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use core::fmt::Debug;
22
use minifloat::example::*;
33
use minifloat::{minifloat, Minifloat, NanStyle};
4+
use num_traits::AsPrimitive;
45

56
minifloat!(struct F8E2M5(u8): 2, 5);
67
minifloat!(struct F8E2M5FN(u8): 2, 5, FN);
@@ -9,6 +10,14 @@ minifloat!(struct F8E2M5FNUZ(u8): 2, 5, FNUZ);
910
minifloat!(struct F8E3M4FNUZ(u8): 3, 4, FNUZ);
1011
minifloat!(struct F8E5M2FN(u8): 5, 2, FN);
1112

13+
const fn bit_mask(width: u32) -> u64 {
14+
if width == 0 {
15+
0
16+
} else {
17+
u64::MAX >> (64 - width)
18+
}
19+
}
20+
1221
/// Test floating-point identity like Object.is in JavaScript
1322
///
1423
/// This is necessary because NaN != NaN in C++. We also want to differentiate
@@ -32,11 +41,17 @@ fn same_mini<T: Minifloat>(x: T, y: T) -> bool {
3241
x.to_bits() == y.to_bits() || x.is_nan() && y.is_nan()
3342
}
3443

35-
fn for_all<T: Minifloat<Bits = u8>>(f: impl Fn(T) -> bool) -> bool {
36-
(0..=u8::MAX).map(T::from_bits).all(f)
44+
fn for_all<T: Minifloat>(f: impl Fn(T) -> bool) -> bool
45+
where
46+
u64: AsPrimitive<T::Bits>,
47+
{
48+
(0..=bit_mask(T::BITWIDTH)).all(|bits| f(T::from_bits(bits.as_())))
3749
}
3850

39-
fn check_equality<T: Minifloat<Bits = u8> + Debug>() -> bool {
51+
fn check_equality<T: Minifloat + Debug>() -> bool
52+
where
53+
u64: AsPrimitive<T::Bits>,
54+
{
4055
let fixed_point = if T::M == 0 { 2.0 } else { 3.0 };
4156
assert!(same_f32(T::from_f32(fixed_point).to_f32(), fixed_point));
4257

0 commit comments

Comments
 (0)