|
1 | 1 | import logging |
2 | 2 | import pandas as pd |
3 | 3 | import numpy as np |
4 | | -from typing import List, Dict |
| 4 | +from typing import Union, List, Dict |
5 | 5 |
|
6 | 6 | import skbio |
7 | 7 | from skbio.diversity import beta_diversity |
|
11 | 11 | from .utils import ( |
12 | 12 | check_index_names, |
13 | 13 | ) |
14 | | - |
| 14 | +from momics.constants import TAXONOMY_RANKS |
15 | 15 |
|
16 | 16 | # logger setup |
17 | 17 | FORMAT = "%(levelname)s | %(name)s | %(message)s" |
@@ -131,9 +131,65 @@ def calculate_shannon_index(df: pd.DataFrame) -> pd.Series: |
131 | 131 | return df.apply(shannon_index, axis=1) |
132 | 132 |
|
133 | 133 |
|
| 134 | +#################### |
| 135 | +# Search functions # |
| 136 | +#################### |
| 137 | +def find_taxa_in_table( |
| 138 | + table: pd.DataFrame, |
| 139 | + tax_level: str, |
| 140 | + search_term: Union[str, int], |
| 141 | + ncbi_tax_id: bool=False, |
| 142 | + exact_match:bool=False, |
| 143 | + ) -> pd.DataFrame: |
| 144 | + """ |
| 145 | + Find taxa in the given table at the specified taxonomic level matching the search term. |
| 146 | +
|
| 147 | + args: |
| 148 | + table (pd.DataFrame): DataFrame containing taxonomic data. |
| 149 | + tax_level (str): Taxonomic level to search ('all' for all levels). |
| 150 | + search_term (str|int): Term to search for. |
| 151 | + ncbi_tax_id (bool): If True, search by NCBI taxonomic ID. |
| 152 | + exact_match (bool): If True, perform exact match; otherwise, use substring match. |
| 153 | +
|
| 154 | + returns: |
| 155 | + pd.DataFrame: DataFrame containing matching taxa. |
| 156 | + """ |
| 157 | + # ncbi_tax_id search |
| 158 | + index_names = getattr(table.index, "names", []) |
| 159 | + if ncbi_tax_id and ('ncbi_tax_id' not in table.columns and 'ncbi_tax_id' not in index_names): |
| 160 | + raise ValueError("The table does not contain 'ncbi_tax_id' column or index level.") |
| 161 | + |
| 162 | + # if ncbi_tax_id is an index level, bring it into a column for uniform handling |
| 163 | + if ncbi_tax_id and ('ncbi_tax_id' in index_names): |
| 164 | + table = table.reset_index() |
| 165 | + |
| 166 | + if ncbi_tax_id: |
| 167 | + # Search by NCBI taxonomic ID |
| 168 | + matching_taxa = table[table['ncbi_tax_id'].astype(str) == str(search_term)] |
| 169 | + return matching_taxa.set_index(index_names) if index_names else matching_taxa |
| 170 | + |
| 171 | + # search by taxonomic level, all ranks |
| 172 | + if tax_level == 'all': |
| 173 | + found = [] |
| 174 | + for tax_level in TAXONOMY_RANKS: |
| 175 | + if exact_match: |
| 176 | + found.append(table[table[tax_level].str.lower().fillna('') == search_term.lower()]) |
| 177 | + else: |
| 178 | + found.append(table[table[tax_level].str.contains(search_term, case=False, na=False)]) |
| 179 | + matching_taxa = pd.concat(found) |
| 180 | + # specific taxonomic level |
| 181 | + else: |
| 182 | + if exact_match: |
| 183 | + matching_taxa = table[table[tax_level].str.lower().fillna('') == search_term.lower()] |
| 184 | + else: |
| 185 | + matching_taxa = table[table[tax_level].str.contains(search_term, case=False, na=False)] |
| 186 | + |
| 187 | + return matching_taxa |
| 188 | + |
134 | 189 | ####################### |
135 | 190 | # diversity functions # |
136 | 191 | ####################### |
| 192 | + |
137 | 193 | def calculate_alpha_diversity(df: pd.DataFrame, factors: pd.DataFrame) -> pd.DataFrame: |
138 | 194 | """ |
139 | 195 | Calculates the alpha diversity (Shannon index) for a DataFrame. |
|
0 commit comments