Skip to content

Commit 4b6217d

Browse files
author
Ian
committed
added new top n genes function
1 parent ca36c46 commit 4b6217d

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

src/sparse/csc.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::iter::Sum;
77
use std::ops::Add;
88
use std::ops::AddAssign;
99

10+
use crate::sparse::MatrixNTop;
1011
use crate::utils::Normalize;
1112

1213
use super::{
@@ -1027,6 +1028,41 @@ impl<M: NumericOps + NumCast> BatchMatrixMean for CscMatrix<M> {
10271028
}
10281029
}
10291030

1031+
impl<M: NumericOps + NumCast> MatrixNTop for CscMatrix<M> {
1032+
type Item = M;
1033+
1034+
fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
1035+
where
1036+
T: Float + NumCast + AddAssign + Sum {
1037+
let mut result = vec![T::zero(); self.nrows()];
1038+
1039+
let mut row_values: Vec<Vec<T>> = vec![Vec::new(); self.nrows()];
1040+
1041+
for col_idx in 0..self.ncols() {
1042+
let col_start = self.col_offsets()[col_idx];
1043+
let col_end = self.col_offsets()[col_idx + 1];
1044+
1045+
for idx in col_start..col_end {
1046+
let row_idx = self.row_indices()[idx];
1047+
if let Some(val) = T::from(self.values()[idx]) {
1048+
row_values[row_idx].push(val);
1049+
}
1050+
}
1051+
}
1052+
1053+
for (row_idx, mut values) in row_values.into_iter().enumerate() {
1054+
if values.len() <= n {
1055+
result[row_idx] = values.into_iter().sum();
1056+
} else {
1057+
values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1058+
result[row_idx] = values.into_iter().take(n).sum();
1059+
}
1060+
}
1061+
1062+
Ok(result)
1063+
}
1064+
}
1065+
10301066
#[cfg(test)]
10311067
mod tests {
10321068
use Direction;

src/sparse/csr.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::ops::{Add, AddAssign};
66
use super::{
77
BatchMatrixMean, BatchMatrixVariance, MatrixMinMax, MatrixNonZero, MatrixSum, MatrixVariance,
88
};
9+
use crate::sparse::MatrixNTop;
910
use crate::utils::Normalize;
1011
use crate::utils::{BatchIdentifier, Log1P};
1112
use anyhow::{anyhow, Ok};
@@ -1033,6 +1034,36 @@ impl<M: NumericOps + NumCast> BatchMatrixMean for CsrMatrix<M> {
10331034
}
10341035
}
10351036

1037+
impl<M: NumericOps + NumCast> MatrixNTop for CsrMatrix<M> {
1038+
type Item = M;
1039+
1040+
fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
1041+
where
1042+
T: Float + NumCast + AddAssign + Sum {
1043+
let mut result = vec![T::zero(); self.nrows()];
1044+
1045+
for row_idx in 0..self.nrows() {
1046+
let row_start = self.row_offsets()[row_idx];
1047+
let row_end = self.row_offsets()[row_idx + 1];
1048+
1049+
let mut row_values: Vec<T> = Vec::new();
1050+
for idx in row_start..row_end {
1051+
if let Some(val) = T::from(self.values()[idx]) {
1052+
row_values.push(val);
1053+
}
1054+
}
1055+
1056+
if row_values.len() <= n {
1057+
result[row_idx] = row_values.into_iter().sum();
1058+
} else {
1059+
row_values.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1060+
result[row_idx] = row_values.into_iter().take(n).sum();
1061+
}
1062+
}
1063+
Ok(result)
1064+
}
1065+
}
1066+
10361067
#[cfg(test)]
10371068
mod tests {
10381069
use Direction;

src/sparse/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,4 +164,12 @@ pub trait BatchMatrixMean {
164164
B: BatchIdentifier;
165165
}
166166

167+
pub trait MatrixNTop {
168+
type Item: NumCast;
169+
170+
fn sum_row_n_top<T>(&self, n: usize) -> anyhow::Result<Vec<T>>
171+
where
172+
T: Float + NumCast + AddAssign + std::iter::Sum;
173+
}
174+
167175

0 commit comments

Comments
 (0)