Skip to content

Commit 03d4bc6

Browse files
committed
add tests and fix impl
1 parent a0cb0d0 commit 03d4bc6

File tree

3 files changed

+130
-18
lines changed

3 files changed

+130
-18
lines changed

pallets/subtensor/src/epoch/math.rs

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,33 @@ pub fn inplace_mask_diag(matrix: &mut [Vec<I32F32>]) {
569569
});
570570
}
571571

572+
// Mask out the diagonal of the input matrix in-place, except for the diagonal entry at except_index.
573+
#[allow(dead_code)]
574+
pub fn inplace_mask_diag_except_index(matrix: &mut [Vec<I32F32>], except_index: u16) {
575+
let Some(first_row) = matrix.first() else {
576+
return;
577+
};
578+
if first_row.is_empty() {
579+
return;
580+
}
581+
assert_eq!(matrix.len(), first_row.len());
582+
583+
let diag_at_index = matrix
584+
.get(except_index as usize)
585+
.and_then(|row| row.get(except_index as usize))
586+
.cloned();
587+
588+
inplace_mask_diag(matrix);
589+
590+
matrix.get_mut(except_index as usize).map(|row| {
591+
row.get_mut(except_index as usize).map(|value| {
592+
if let Some(diag_at_index) = diag_at_index {
593+
*value = diag_at_index;
594+
}
595+
})
596+
});
597+
}
598+
572599
// Return a new sparse matrix that replaces masked rows with an empty vector placeholder.
573600
#[allow(dead_code)]
574601
pub fn mask_rows_sparse(
@@ -611,23 +638,20 @@ pub fn mask_diag_sparse_except_index(
611638
sparse_matrix: &[Vec<(u16, I32F32)>],
612639
except_index: u16,
613640
) -> Vec<Vec<(u16, I32F32)>> {
614-
// Store the diagonal entry at except_index
615-
let diag_at_index = sparse_matrix
616-
.get(except_index as usize)
617-
.and_then(|row| row.get(except_index as usize))
618-
.cloned();
619-
// Mask out the diagonal
620-
let mut result = mask_diag_sparse(sparse_matrix);
621-
// Replace the diagonal entry at except_index using only get_mut or map
622-
result.get_mut(except_index as usize).map(|row| {
623-
row.get_mut(except_index as usize).map(|value| {
624-
if let Some(diag_at_index) = diag_at_index {
625-
*value = diag_at_index;
626-
}
641+
sparse_matrix
642+
.iter()
643+
.enumerate()
644+
.map(|(i, sparse_row)| {
645+
sparse_row
646+
.iter()
647+
.filter(|(j, _)| {
648+
// Is not a diagonal OR is the diagonal at except_index
649+
i != (*j as usize) || (i == except_index as usize && *j == except_index)
650+
})
651+
.copied()
652+
.collect()
627653
})
628-
});
629-
630-
result
654+
.collect()
631655
}
632656

633657
// Remove cells from sparse matrix where the mask function of two vectors is true.

pallets/subtensor/src/epoch/run_epoch.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ impl<T: Config> Pallet<T> {
111111
// == Weights ==
112112
// =============
113113

114+
// Get owner uid.
115+
let owner_uid: Option<u16> = Self::get_owner_uid(netuid);
116+
114117
// Access network weights row unnormalized.
115118
let mut weights: Vec<Vec<I32F32>> = Self::get_weights(netuid);
116119
log::trace!("W:\n{:?}\n", &weights);
@@ -119,7 +122,13 @@ impl<T: Config> Pallet<T> {
119122
inplace_mask_rows(&validator_forbids, &mut weights);
120123
log::trace!("W (permit): {:?}", &weights);
121124

122-
// Remove self-weight by masking diagonal.
125+
// Remove self-weight by masking diagonal; keep owner_uid self-weight.
126+
if let Some(owner_uid) = owner_uid {
127+
inplace_mask_diag_except_index(&mut weights, owner_uid);
128+
} else {
129+
inplace_mask_diag(&mut weights);
130+
}
131+
123132
inplace_mask_diag(&mut weights);
124133
log::trace!("W (permit+diag):\n{:?}\n", &weights);
125134

pallets/subtensor/src/tests/math.rs

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,13 @@ fn assert_sparse_mat_compare(
5858
) {
5959
assert!(ma.len() == mb.len());
6060
for row in 0..ma.len() {
61-
assert!(ma[row].len() == mb[row].len());
61+
assert!(
62+
ma[row].len() == mb[row].len(),
63+
"row: {}, ma: {:?}, mb: {:?}",
64+
row,
65+
ma[row],
66+
mb[row]
67+
);
6268
for j in 0..ma[row].len() {
6369
assert!(ma[row][j].0 == mb[row][j].0); // u16
6470
assert_float_compare(ma[row][j].1, mb[row][j].1, epsilon) // I32F32
@@ -1034,6 +1040,27 @@ fn test_math_inplace_mask_diag() {
10341040
);
10351041
}
10361042

1043+
#[test]
1044+
fn test_math_inplace_mask_diag_except_index() {
1045+
let vector: Vec<f32> = vec![1., 2., 3., 4., 5., 6., 7., 8., 9.];
1046+
let rows = 3;
1047+
1048+
for i in 0..rows {
1049+
let mut target: Vec<f32> = vec![0., 2., 3., 4., 0., 6., 7., 8., 0.];
1050+
let row = i * rows;
1051+
let col = i;
1052+
target[row + col] = vector[row + col];
1053+
1054+
let mut mat = vec_to_mat_fixed(&vector, rows, false);
1055+
inplace_mask_diag_except_index(&mut mat, i as u16);
1056+
assert_mat_compare(
1057+
&mat,
1058+
&vec_to_mat_fixed(&target, rows, false),
1059+
I32F32::from_num(0),
1060+
);
1061+
}
1062+
}
1063+
10371064
#[test]
10381065
fn test_math_mask_rows_sparse() {
10391066
let input: Vec<f32> = vec![1., 2., 3., 4., 5., 6., 7., 8., 9.];
@@ -1105,6 +1132,58 @@ fn test_math_mask_diag_sparse() {
11051132
);
11061133
}
11071134

1135+
#[test]
1136+
fn test_math_mask_diag_sparse_except_index() {
1137+
let rows = 3;
1138+
1139+
let vector: Vec<f32> = vec![1., 2., 3., 4., 5., 6., 7., 8., 9.];
1140+
let mat = vec_to_sparse_mat_fixed(&vector, rows, false);
1141+
1142+
for i in 0..rows {
1143+
let mut target: Vec<f32> = vec![0., 2., 3., 4., 0., 6., 7., 8., 0.];
1144+
let row = i * rows;
1145+
let col = i;
1146+
target[row + col] = vector[row + col];
1147+
1148+
let result = mask_diag_sparse_except_index(&mat, i as u16);
1149+
let target_as_mat = vec_to_sparse_mat_fixed(&target, rows, false);
1150+
1151+
assert_sparse_mat_compare(&result, &target_as_mat, I32F32::from_num(0));
1152+
}
1153+
1154+
let vector: Vec<f32> = vec![1., 0., 0., 0., 5., 0., 0., 0., 9.];
1155+
let mat = vec_to_sparse_mat_fixed(&vector, rows, false);
1156+
1157+
for i in 0..rows {
1158+
let mut target: Vec<f32> = vec![0., 0., 0., 0., 0., 0., 0., 0., 0.];
1159+
let row = i * rows;
1160+
let col = i;
1161+
target[row + col] = vector[row + col];
1162+
1163+
let result = mask_diag_sparse_except_index(&mat, i as u16);
1164+
let target_as_mat = vec_to_sparse_mat_fixed(&target, rows, false);
1165+
assert_eq!(result.len(), target_as_mat.len());
1166+
1167+
assert_sparse_mat_compare(&result, &target_as_mat, I32F32::from_num(0));
1168+
}
1169+
1170+
for i in 0..rows {
1171+
let vector: Vec<f32> = vec![0., 0., 0., 0., 0., 0., 0., 0., 0.];
1172+
let mat = vec_to_sparse_mat_fixed(&vector, rows, false);
1173+
1174+
let mut target: Vec<f32> = vec![0., 0., 0., 0., 0., 0., 0., 0., 0.];
1175+
let row = i * rows;
1176+
let col = i;
1177+
target[row + col] = vector[row + col];
1178+
1179+
let result = mask_diag_sparse_except_index(&mat, i as u16);
1180+
let target_as_mat = vec_to_sparse_mat_fixed(&target, rows, false);
1181+
assert_eq!(result.len(), target_as_mat.len());
1182+
1183+
assert_sparse_mat_compare(&result, &target_as_mat, I32F32::from_num(0));
1184+
}
1185+
}
1186+
11081187
#[test]
11091188
fn test_math_vec_mask_sparse_matrix() {
11101189
let vector: Vec<f32> = vec![1., 2., 3., 4., 5., 6., 7., 8., 9.];

0 commit comments

Comments
 (0)