@@ -58,7 +58,13 @@ fn assert_sparse_mat_compare(
58
58
) {
59
59
assert ! ( ma. len( ) == mb. len( ) ) ;
60
60
for row in 0 ..ma. len ( ) {
61
- assert ! ( ma[ row] . len( ) == mb[ row] . len( ) ) ;
61
+ assert ! (
62
+ ma[ row] . len( ) == mb[ row] . len( ) ,
63
+ "row: {}, ma: {:?}, mb: {:?}" ,
64
+ row,
65
+ ma[ row] ,
66
+ mb[ row]
67
+ ) ;
62
68
for j in 0 ..ma[ row] . len ( ) {
63
69
assert ! ( ma[ row] [ j] . 0 == mb[ row] [ j] . 0 ) ; // u16
64
70
assert_float_compare ( ma[ row] [ j] . 1 , mb[ row] [ j] . 1 , epsilon) // I32F32
@@ -1034,6 +1040,27 @@ fn test_math_inplace_mask_diag() {
1034
1040
) ;
1035
1041
}
1036
1042
1043
+ #[ test]
1044
+ fn test_math_inplace_mask_diag_except_index ( ) {
1045
+ let vector: Vec < f32 > = vec ! [ 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. ] ;
1046
+ let rows = 3 ;
1047
+
1048
+ for i in 0 ..rows {
1049
+ let mut target: Vec < f32 > = vec ! [ 0. , 2. , 3. , 4. , 0. , 6. , 7. , 8. , 0. ] ;
1050
+ let row = i * rows;
1051
+ let col = i;
1052
+ target[ row + col] = vector[ row + col] ;
1053
+
1054
+ let mut mat = vec_to_mat_fixed ( & vector, rows, false ) ;
1055
+ inplace_mask_diag_except_index ( & mut mat, i as u16 ) ;
1056
+ assert_mat_compare (
1057
+ & mat,
1058
+ & vec_to_mat_fixed ( & target, rows, false ) ,
1059
+ I32F32 :: from_num ( 0 ) ,
1060
+ ) ;
1061
+ }
1062
+ }
1063
+
1037
1064
#[ test]
1038
1065
fn test_math_mask_rows_sparse ( ) {
1039
1066
let input: Vec < f32 > = vec ! [ 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. ] ;
@@ -1105,6 +1132,58 @@ fn test_math_mask_diag_sparse() {
1105
1132
) ;
1106
1133
}
1107
1134
1135
+ #[ test]
1136
+ fn test_math_mask_diag_sparse_except_index ( ) {
1137
+ let rows = 3 ;
1138
+
1139
+ let vector: Vec < f32 > = vec ! [ 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. ] ;
1140
+ let mat = vec_to_sparse_mat_fixed ( & vector, rows, false ) ;
1141
+
1142
+ for i in 0 ..rows {
1143
+ let mut target: Vec < f32 > = vec ! [ 0. , 2. , 3. , 4. , 0. , 6. , 7. , 8. , 0. ] ;
1144
+ let row = i * rows;
1145
+ let col = i;
1146
+ target[ row + col] = vector[ row + col] ;
1147
+
1148
+ let result = mask_diag_sparse_except_index ( & mat, i as u16 ) ;
1149
+ let target_as_mat = vec_to_sparse_mat_fixed ( & target, rows, false ) ;
1150
+
1151
+ assert_sparse_mat_compare ( & result, & target_as_mat, I32F32 :: from_num ( 0 ) ) ;
1152
+ }
1153
+
1154
+ let vector: Vec < f32 > = vec ! [ 1. , 0. , 0. , 0. , 5. , 0. , 0. , 0. , 9. ] ;
1155
+ let mat = vec_to_sparse_mat_fixed ( & vector, rows, false ) ;
1156
+
1157
+ for i in 0 ..rows {
1158
+ let mut target: Vec < f32 > = vec ! [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ] ;
1159
+ let row = i * rows;
1160
+ let col = i;
1161
+ target[ row + col] = vector[ row + col] ;
1162
+
1163
+ let result = mask_diag_sparse_except_index ( & mat, i as u16 ) ;
1164
+ let target_as_mat = vec_to_sparse_mat_fixed ( & target, rows, false ) ;
1165
+ assert_eq ! ( result. len( ) , target_as_mat. len( ) ) ;
1166
+
1167
+ assert_sparse_mat_compare ( & result, & target_as_mat, I32F32 :: from_num ( 0 ) ) ;
1168
+ }
1169
+
1170
+ for i in 0 ..rows {
1171
+ let vector: Vec < f32 > = vec ! [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ] ;
1172
+ let mat = vec_to_sparse_mat_fixed ( & vector, rows, false ) ;
1173
+
1174
+ let mut target: Vec < f32 > = vec ! [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ] ;
1175
+ let row = i * rows;
1176
+ let col = i;
1177
+ target[ row + col] = vector[ row + col] ;
1178
+
1179
+ let result = mask_diag_sparse_except_index ( & mat, i as u16 ) ;
1180
+ let target_as_mat = vec_to_sparse_mat_fixed ( & target, rows, false ) ;
1181
+ assert_eq ! ( result. len( ) , target_as_mat. len( ) ) ;
1182
+
1183
+ assert_sparse_mat_compare ( & result, & target_as_mat, I32F32 :: from_num ( 0 ) ) ;
1184
+ }
1185
+ }
1186
+
1108
1187
#[ test]
1109
1188
fn test_math_vec_mask_sparse_matrix ( ) {
1110
1189
let vector: Vec < f32 > = vec ! [ 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9. ] ;
0 commit comments