Skip to content

Commit 7065544

Browse files
committed
Explictly call Ord::cmp to compare priorities
Previously we would call `PartialOrd::partial_cmp` implicilty using the overloaded comparison operators `>` and `<` when comparing priorities, despite the fact that P is required to be `Ord`. A well behaved implementation of `partial_cmp` is supposed to return `Some(Ord::cmp(self, other))`. That relationship is only a convention, so a misbehaved implementation may return None, which may cause the order items are popped from the queue to behave seemingly randomly. We can be a bit more defensive here and instead always call `Ord::cmp` directly, ensuring that we never try to compare things that could possibly return `None`. In order to enforce this going forward I added a test that panics in the implementation for partial_cmp and exercised all of the code paths that might call it. This isn't perfect, since new callsites could be added, but I figure its probably good enough for now. Not sure exactly how to version this, it's not a breaking change if you implementation of `PartialOrd` follows the convention, but if you don't it would be a breaking change, since we'd be using a different function for comparisons.
1 parent 95499eb commit 7065544

File tree

4 files changed

+126
-18
lines changed

4 files changed

+126
-18
lines changed

src/double_priority_queue/mod.rs

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,10 @@ where
665665
///
666666
/// Computes in **O(log(N))** time.
667667
pub fn push_increase(&mut self, item: I, priority: P) -> Option<P> {
668-
if self.get_priority(&item).map_or(true, |p| priority > *p) {
668+
if self
669+
.get_priority(&item)
670+
.map_or(true, |p| priority.cmp(p).is_gt())
671+
{
669672
self.push(item, priority)
670673
} else {
671674
Some(priority)
@@ -705,7 +708,10 @@ where
705708
///
706709
/// Computes in **O(log(N))** time.
707710
pub fn push_decrease(&mut self, item: I, priority: P) -> Option<P> {
708-
if self.get_priority(&item).map_or(true, |p| priority < *p) {
711+
if self
712+
.get_priority(&item)
713+
.map_or(true, |p| priority.cmp(p).is_lt())
714+
{
709715
self.push(item, priority)
710716
} else {
711717
Some(priority)
@@ -901,15 +907,20 @@ where
901907
.0;
902908

903909
if unsafe {
904-
self.store.get_priority_from_position(i) < self.store.get_priority_from_position(m)
910+
self.store
911+
.get_priority_from_position(i)
912+
.cmp(self.store.get_priority_from_position(m))
913+
.is_lt()
905914
} {
906915
self.store.swap(i, m);
907916
if i > r {
908917
// i is a grandchild of m
909918
let p = parent(i);
910919
if unsafe {
911-
self.store.get_priority_from_position(i)
912-
> self.store.get_priority_from_position(p)
920+
self.store
921+
.get_priority_from_position(i)
922+
.cmp(self.store.get_priority_from_position(p))
923+
.is_gt()
913924
} {
914925
self.store.swap(i, p);
915926
}
@@ -943,15 +954,20 @@ where
943954
.0;
944955

945956
if unsafe {
946-
self.store.get_priority_from_position(i) > self.store.get_priority_from_position(m)
957+
self.store
958+
.get_priority_from_position(i)
959+
.cmp(self.store.get_priority_from_position(m))
960+
.is_gt()
947961
} {
948962
self.store.swap(i, m);
949963
if i > r {
950964
// i is a grandchild of m
951965
let p = parent(i);
952966
if unsafe {
953-
self.store.get_priority_from_position(i)
954-
< self.store.get_priority_from_position(p)
967+
self.store
968+
.get_priority_from_position(i)
969+
.cmp(self.store.get_priority_from_position(p))
970+
.is_lt()
955971
} {
956972
self.store.swap(i, p);
957973
}
@@ -970,7 +986,10 @@ where
970986
let parent = parent(position);
971987
let parent_priority = unsafe { self.store.get_priority_from_position(parent) };
972988
let parent_index = unsafe { *self.store.heap.get_unchecked(parent.0) };
973-
position = match (level(position) % 2 == 0, parent_priority < priority) {
989+
position = match (
990+
level(position) % 2 == 0,
991+
parent_priority.cmp(priority).is_lt(),
992+
) {
974993
// on a min level and greater then parent
975994
(true, true) => {
976995
unsafe {
@@ -1008,7 +1027,9 @@ where
10081027
let mut grand_parent = Position(0);
10091028
while if position.0 > 0 && parent(position).0 > 0 {
10101029
grand_parent = parent(parent(position));
1011-
(unsafe { self.store.get_priority_from_position(grand_parent) }) > priority
1030+
(unsafe { self.store.get_priority_from_position(grand_parent) })
1031+
.cmp(priority)
1032+
.is_gt()
10121033
} else {
10131034
false
10141035
} {
@@ -1027,7 +1048,9 @@ where
10271048
let mut grand_parent = Position(0);
10281049
while if position.0 > 0 && parent(position).0 > 0 {
10291050
grand_parent = parent(parent(position));
1030-
(unsafe { self.store.get_priority_from_position(grand_parent) }) < priority
1051+
(unsafe { self.store.get_priority_from_position(grand_parent) })
1052+
.cmp(priority)
1053+
.is_lt()
10311054
} else {
10321055
false
10331056
} {

src/priority_queue/mod.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,10 @@ where
539539
///
540540
/// Computes in **O(log(N))** time.
541541
pub fn push_increase(&mut self, item: I, priority: P) -> Option<P> {
542-
if self.get_priority(&item).map_or(true, |p| priority > *p) {
542+
if self
543+
.get_priority(&item)
544+
.map_or(true, |p| priority.cmp(p).is_gt())
545+
{
543546
self.push(item, priority)
544547
} else {
545548
Some(priority)
@@ -579,7 +582,10 @@ where
579582
///
580583
/// Computes in **O(log(N))** time.
581584
pub fn push_decrease(&mut self, item: I, priority: P) -> Option<P> {
582-
if self.get_priority(&item).map_or(true, |p| priority < *p) {
585+
if self
586+
.get_priority(&item)
587+
.map_or(true, |p| priority.cmp(p).is_lt())
588+
{
583589
self.push(item, priority)
584590
} else {
585591
Some(priority)
@@ -762,12 +768,16 @@ where
762768
let mut largestp = unsafe { self.store.get_priority_from_position(i) };
763769
if l.0 < self.len() {
764770
let childp = unsafe { self.store.get_priority_from_position(l) };
765-
if childp > largestp {
771+
if childp.cmp(largestp).is_gt() {
766772
largest = l;
767773
largestp = childp;
768774
}
769775

770-
if r.0 < self.len() && unsafe { self.store.get_priority_from_position(r) } > largestp {
776+
if r.0 < self.len()
777+
&& unsafe { self.store.get_priority_from_position(r) }
778+
.cmp(largestp)
779+
.is_gt()
780+
{
771781
largest = r;
772782
}
773783
}
@@ -780,14 +790,16 @@ where
780790
l = left(i);
781791
if l.0 < self.len() {
782792
let childp = unsafe { self.store.get_priority_from_position(l) };
783-
if childp > largestp {
793+
if childp.cmp(largestp).is_gt() {
784794
largest = l;
785795
largestp = childp;
786796
}
787797

788798
r = right(i);
789799
if r.0 < self.len()
790-
&& unsafe { self.store.get_priority_from_position(r) } > largestp
800+
&& unsafe { self.store.get_priority_from_position(r) }
801+
.cmp(largestp)
802+
.is_gt()
791803
{
792804
largest = r;
793805
}
@@ -802,7 +814,9 @@ where
802814
let mut parent_position = Position(0);
803815
while if position.0 > 0 {
804816
parent_position = parent(position);
805-
(unsafe { self.store.get_priority_from_position(parent_position) }) < priority
817+
(unsafe { self.store.get_priority_from_position(parent_position) })
818+
.cmp(priority)
819+
.is_lt()
806820
} else {
807821
false
808822
} {

tests/double_priority_queue.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,41 @@ mod doublepq_tests {
12751275
);
12761276
}
12771277
}
1278+
1279+
#[test]
1280+
fn partial_cmp_not_called() {
1281+
use std::cmp::{Ordering, PartialOrd};
1282+
1283+
#[derive(Debug, PartialEq, Eq, Hash, Ord)]
1284+
struct PanicPartial(i64);
1285+
1286+
// This is an invalid implementation of PartialOrd according to
1287+
// the docs in `std::cmp::PartialOrd`, since Ord is also implemented,
1288+
// this should always return Some(Ord::cmp(self, other)). Instead this
1289+
// function panics as a way to ensure that we don't accidently
1290+
// use PartialOrd::partial_cmp when we should be using Ord::cmp
1291+
// instead. Enforcing the explicit use of Ord::cmp lets us rely on
1292+
// the compiler instead of the convention that PartialOrd::partial_cmp
1293+
// _should_ call Ord::cmp
1294+
impl PartialOrd for PanicPartial {
1295+
fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
1296+
panic!("partial_cmp should not be called");
1297+
}
1298+
}
1299+
1300+
let mut dpq = DoublePriorityQueue::new();
1301+
dpq.push(0, PanicPartial(100));
1302+
dpq.push(1, PanicPartial(200));
1303+
dpq.push(2, PanicPartial(150));
1304+
dpq.push_increase(2, PanicPartial(300));
1305+
dpq.push_decrease(2, PanicPartial(0));
1306+
1307+
// These asserts are redundant since this behavior is tested elsewhere, we're
1308+
// mainly just interested in not panicking for this test.
1309+
assert_eq!(dpq.pop_min(), Some((2, PanicPartial(0))));
1310+
assert_eq!(dpq.pop_max(), Some((1, PanicPartial(200))));
1311+
assert_eq!(dpq.pop_min(), Some((0, PanicPartial(100))));
1312+
}
12781313
}
12791314

12801315
#[cfg(all(feature = "serde", test))]

tests/priority_queue.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,42 @@ mod pqueue_tests {
11431143
assert_eq!(removed_priority, 200);
11441144
assert!(!pq.contains(&bob_view));
11451145
}
1146+
1147+
#[test]
1148+
fn partial_cmp_not_called() {
1149+
use std::cmp::{Ordering, PartialOrd};
1150+
1151+
#[derive(Debug, PartialEq, Eq, Hash, Ord)]
1152+
struct PanicPartial(i64);
1153+
1154+
// This is an invalid implementation of PartialOrd according to
1155+
// the docs in `std::cmp::PartialOrd`, since Ord is also implemented,
1156+
// this should always return Some(Ord::cmp(self, other)). Instead this
1157+
// function panics as a way to ensure that we don't accidently
1158+
// use PartialOrd::partial_cmp when we should be using Ord::cmp
1159+
// instead. Enforcing the explicit use of Ord::cmp lets us rely on
1160+
// the compiler instead of the convention that PartialOrd::partial_cmp
1161+
// _should_ call Ord::cmp
1162+
impl PartialOrd for PanicPartial {
1163+
fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
1164+
panic!("partial_cmp should not be called");
1165+
}
1166+
}
1167+
1168+
// Push persons into queue
1169+
let mut pq = PriorityQueue::new();
1170+
pq.push(0, PanicPartial(100));
1171+
pq.push(1, PanicPartial(200));
1172+
pq.push(2, PanicPartial(150));
1173+
pq.push_increase(2, PanicPartial(300));
1174+
pq.push_decrease(2, PanicPartial(0));
1175+
1176+
// These asserts are redundant since this behavior is tested elsewhere, we're
1177+
// mainly just interested in not panicking for this test.
1178+
assert_eq!(pq.pop(), Some((1, PanicPartial(200))));
1179+
assert_eq!(pq.pop(), Some((0, PanicPartial(100))));
1180+
assert_eq!(pq.pop(), Some((2, PanicPartial(0))));
1181+
}
11461182
}
11471183

11481184
#[cfg(all(feature = "serde", test))]

0 commit comments

Comments
 (0)