Skip to content

Commit 6d45418

Browse files
committed
added functionality for no ontology columns or level 0
1 parent e583cb6 commit 6d45418

File tree

2 files changed

+55
-102
lines changed

2 files changed

+55
-102
lines changed

pages/01_Create_RDD_Count_Table.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,14 @@ def load_demo_file(filename):
252252
# -------- other parameters --------
253253
sample_type = st.selectbox("Reference sample type", ("all", "simple", "complex"))
254254
ontology_cols = st.text_input("Custom ontology columns (comma-separated)", "")
255-
levels_val = st.number_input("Maximum ontology levels to analyse", 1, 10, 6, 1)
255+
levels_val = st.number_input(
256+
"Maximum ontology levels to analyse",
257+
0,
258+
10,
259+
6,
260+
1,
261+
help="Set to 0 for file-level counts only (no ontology aggregation)",
262+
)
256263

257264
if ontology_cols and levels_val > len([c for c in ontology_cols.split(",") if c.strip()]):
258265
st.warning("Reducing 'levels' to match number of ontology columns.")

src/RDDcounts.py

Lines changed: 47 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,9 @@ def __init__(
6868
task_id: Optional[str] = None,
6969
gnps_2: bool = True,
7070
) -> None:
71-
7271
"""
7372
Initialize an RDDCounts instance and prepare all internal data structures used to compute RDD counts across ontology levels.
74-
73+
7574
Parameters:
7675
sample_types (str): Identifier or path used to load reference/sample ontology assignments.
7776
gnps_network_path (Optional[str]): Path to a GNPS network TSV; provide exactly one of this or `task_id`.
@@ -85,17 +84,15 @@ def __init__(
8584
blank_identifier (str): Reference label used as the blank/water baseline to subtract from counts (default "water").
8685
task_id (Optional[str]): GNPS task identifier to fetch network data; provide exactly one of this or `gnps_network_path`.
8786
gnps_2 (bool): Whether to fetch GNPS v2-formatted task data when `task_id` is used.
88-
87+
8988
Side effects:
9089
- Validates that exactly one of `task_id` or `gnps_network_path` is provided and raises ValueError otherwise.
9190
- Loads GNPS network, reference metadata, sample metadata, and sample type mappings.
9291
- Determines and validates ontology `levels` (auto-detects when `levels` is `None`).
9392
- Builds `ontology_table`, normalizes the GNPS network, and computes `file_level_counts` and aggregated `counts` across all taxonomy levels as instance attributes.
9493
"""
9594
if (task_id is None) == (gnps_network_path is None):
96-
raise ValueError(
97-
"Provide exactly one of task_id or gnps_network_path."
98-
)
95+
raise ValueError("Provide exactly one of task_id or gnps_network_path.")
9996
if task_id is None:
10097
self.raw_gnps_network = pd.read_csv(gnps_network_path, sep="\t")
10198
else:
@@ -107,34 +104,32 @@ def __init__(
107104
self.sample_group_col = sample_group_col
108105
self.blank_identifier = blank_identifier
109106

110-
self.reference_metadata = _load_RDD_metadata(
111-
external_reference_metadata
112-
)
107+
self.reference_metadata = _load_RDD_metadata(external_reference_metadata)
113108
self.sample_metadata = get_sample_metadata(
114109
self.raw_gnps_network,
115110
sample_groups,
116111
external_sample_metadata,
117112
filename_col="filename",
118113
)
119-
self.sample_types_df, self.ontology_columns_renamed = (
120-
_load_sample_types(
121-
self.reference_metadata,
122-
sample_types,
123-
ontology_columns=ontology_columns,
124-
)
114+
self.sample_types_df, self.ontology_columns_renamed = _load_sample_types(
115+
self.reference_metadata,
116+
sample_types,
117+
ontology_columns=ontology_columns,
125118
)
126119

127120
# Auto-determine levels if not specified
128121
if levels is None:
129122
self.levels = self._determine_ontology_levels()
130123
else:
131124
self.levels = levels
132-
# Validate that specified levels don't exceed available columns
125+
# Validate that specified levels are valid
126+
if self.levels < 0:
127+
raise ValueError(f"levels must be non-negative, got {self.levels}.")
133128
available_levels = self._determine_ontology_levels()
134129
if self.levels > available_levels:
135130
raise ValueError(
136131
f"levels ({self.levels}) exceeds available ontology columns "
137-
f"({available_levels})."
132+
f"({available_levels}). Use levels=0 for file-level counts only."
138133
)
139134

140135
self.ontology_table = (
@@ -155,36 +150,37 @@ def __init__(
155150
def _determine_ontology_levels(self) -> int:
156151
"""
157152
Determine how many ontology levels are available for RDD aggregation.
158-
153+
159154
Checks for a user-provided list of renamed ontology columns and returns its length if present;
160155
otherwise counts columns in `reference_metadata` that match the pattern `sample_type_group{n}`.
161-
156+
162157
Returns:
163158
int: Number of available ontology levels.
164159
"""
165160
if self.ontology_columns_renamed:
166161
return len(self.ontology_columns_renamed)
167-
162+
168163
# Count sample_type_groupX columns in reference metadata
169164
ontology_cols = [
170-
col for col in self.reference_metadata.columns
165+
col
166+
for col in self.reference_metadata.columns
171167
if re.match(r"sample_type_group\d+$", col)
172168
]
173169
return len(ontology_cols)
174170

175171
def _get_ontology_column_for_level(self, level: int) -> str:
176172
"""
177173
Return the ontology column name for the specified ontology level.
178-
174+
179175
If custom ontology columns were provided (stored in self.ontology_columns_renamed), the method returns
180176
the corresponding renamed column; otherwise it returns the default name "sample_type_group{level}".
181-
177+
182178
Parameters:
183179
level (int): 1-based ontology level to resolve.
184-
180+
185181
Returns:
186182
str: Column name corresponding to the given level.
187-
183+
188184
Raises:
189185
ValueError: If `self.ontology_columns_renamed` is set and `level` is outside the valid range
190186
(1..len(self.ontology_columns_renamed)).
@@ -242,9 +238,7 @@ def file_counts(
242238
sample_group_col=self.sample_group_col,
243239
reference_name_col=reference_name_col,
244240
)
245-
sample_clusters.drop_duplicates(
246-
subset=["filename", "cluster_index"], inplace=True
247-
)
241+
sample_clusters.drop_duplicates(subset=["filename", "cluster_index"], inplace=True)
248242
reference_clusters.drop_duplicates(
249243
subset=["filename", reference_name_col, "cluster_index"],
250244
inplace=True,
@@ -280,9 +274,7 @@ def file_counts(
280274
filename_to_group = self.sample_metadata.set_index("filename")[
281275
self.sample_group_col
282276
].to_dict()
283-
cluster_count_long["group"] = cluster_count_long["filename"].map(
284-
filename_to_group
285-
)
277+
cluster_count_long["group"] = cluster_count_long["filename"].map(filename_to_group)
286278
return cluster_count_long
287279

288280
def create_RDD_counts_all_levels(self) -> pd.DataFrame:
@@ -309,9 +301,7 @@ def create_RDD_counts_all_levels(self) -> pd.DataFrame:
309301
RDD_counts_file_level
310302
] # Initialize a list for storing data at all levels
311303
if "reference_type" not in RDD_counts_file_level.columns:
312-
raise ValueError(
313-
"Expected 'reference_type' column in file-level counts."
314-
)
304+
raise ValueError("Expected 'reference_type' column in file-level counts.")
315305
RDD_counts_file_level_sample_types = RDD_counts_file_level.merge(
316306
self.sample_types_df,
317307
left_on="reference_type",
@@ -327,9 +317,7 @@ def create_RDD_counts_all_levels(self) -> pd.DataFrame:
327317
ontology_col = self._get_ontology_column_for_level(level)
328318

329319
RDD_counts_level = (
330-
RDD_counts_file_level_sample_types.groupby(
331-
["filename", ontology_col]
332-
)["count"]
320+
RDD_counts_file_level_sample_types.groupby(["filename", ontology_col])["count"]
333321
.sum()
334322
.reset_index()
335323
)
@@ -346,21 +334,15 @@ def create_RDD_counts_all_levels(self) -> pd.DataFrame:
346334
columns_to_modify = wide_format_counts.columns.difference(
347335
["filename", self.blank_identifier]
348336
)
349-
wide_format_counts.loc[
337+
wide_format_counts.loc[:, columns_to_modify] = wide_format_counts.loc[
350338
:, columns_to_modify
351-
] = wide_format_counts.loc[:, columns_to_modify].where(
352-
wide_format_counts.loc[:, columns_to_modify].gt(
353-
water_counts, axis=0
354-
),
339+
].where(
340+
wide_format_counts.loc[:, columns_to_modify].gt(water_counts, axis=0),
355341
0,
356342
)
357-
wide_format_counts = wide_format_counts.drop(
358-
columns=[self.blank_identifier]
359-
)
343+
wide_format_counts = wide_format_counts.drop(columns=[self.blank_identifier])
360344

361-
wide_format_counts = wide_format_counts.loc[
362-
:, (wide_format_counts != 0).any(axis=0)
363-
]
345+
wide_format_counts = wide_format_counts.loc[:, (wide_format_counts != 0).any(axis=0)]
364346
if wide_format_counts.empty:
365347
continue # Skip this level
366348
RDD_counts_level = wide_format_counts.melt(
@@ -372,19 +354,13 @@ def create_RDD_counts_all_levels(self) -> pd.DataFrame:
372354

373355
RDD_counts_all_levels.append(RDD_counts_level)
374356

375-
RDD_counts_all_levels = pd.concat(
376-
RDD_counts_all_levels, ignore_index=True
377-
)
357+
RDD_counts_all_levels = pd.concat(RDD_counts_all_levels, ignore_index=True)
378358

379359
# Map group information from the sample_metadata to the final DataFrame
380-
RDD_counts_all_levels["group"] = RDD_counts_all_levels["filename"].map(
381-
sample_metadata_map
382-
)
360+
RDD_counts_all_levels["group"] = RDD_counts_all_levels["filename"].map(sample_metadata_map)
383361

384362
# Cast 'count' as an integer
385-
RDD_counts_all_levels["count"] = RDD_counts_all_levels["count"].astype(
386-
int
387-
)
363+
RDD_counts_all_levels["count"] = RDD_counts_all_levels["count"].astype(int)
388364

389365
return RDD_counts_all_levels
390366

@@ -440,19 +416,13 @@ def filter_counts(
440416
)
441417

442418
# Resolve level column names via helper to support both default and custom columns
443-
upper_ontology_col = self._get_ontology_column_for_level(
444-
upper_level
445-
)
446-
lower_ontology_col = self._get_ontology_column_for_level(
447-
lower_level
448-
)
419+
upper_ontology_col = self._get_ontology_column_for_level(upper_level)
420+
lower_ontology_col = self._get_ontology_column_for_level(lower_level)
449421

450422
# Matching lower-level reference types
451423
reference_types = (
452424
self.ontology_table[
453-
self.ontology_table[upper_ontology_col].isin(
454-
upper_level_reference_types
455-
)
425+
self.ontology_table[upper_ontology_col].isin(upper_level_reference_types)
456426
][lower_ontology_col]
457427
.dropna()
458428
.unique()
@@ -469,9 +439,7 @@ def filter_counts(
469439
if sample_names:
470440
if isinstance(sample_names, str):
471441
sample_names = [sample_names]
472-
filtered_df = filtered_df[
473-
filtered_df["filename"].isin(sample_names)
474-
]
442+
filtered_df = filtered_df[filtered_df["filename"].isin(sample_names)]
475443

476444
# Filter by group(s)
477445
if group is not None:
@@ -481,24 +449,18 @@ def filter_counts(
481449

482450
# Filter by explicitly provided reference_types first (before top_n)
483451
if reference_types is not None:
484-
filtered_df = filtered_df[
485-
filtered_df["reference_type"].isin(reference_types)
486-
]
452+
filtered_df = filtered_df[filtered_df["reference_type"].isin(reference_types)]
487453

488454
# Select top N reference types if requested (only if reference_types not explicitly provided)
489455
if top_n is not None and reference_types is None:
490456
if top_n_method == "per_sample":
491457
top_df = (
492-
filtered_df.sort_values(
493-
["filename", "count"], ascending=[True, False]
494-
)
458+
filtered_df.sort_values(["filename", "count"], ascending=[True, False])
495459
.groupby("filename")
496460
.head(top_n)
497461
.reset_index(drop=True)
498462
)
499-
top_reference_types = (
500-
top_df["reference_type"].dropna().unique().tolist()
501-
)
463+
top_reference_types = top_df["reference_type"].dropna().unique().tolist()
502464

503465
elif top_n_method == "total":
504466
top_reference_types = (
@@ -521,9 +483,7 @@ def filter_counts(
521483
"Invalid top_n_method. Choose from 'per_sample', 'total', or 'average'."
522484
)
523485

524-
filtered_df = filtered_df[
525-
filtered_df["reference_type"].isin(top_reference_types)
526-
]
486+
filtered_df = filtered_df[filtered_df["reference_type"].isin(top_reference_types)]
527487

528488
return filtered_df
529489

@@ -560,9 +520,7 @@ def update_groups(
560520

561521
# Update the 'group' column in counts using the mapping
562522
self.counts["group"] = (
563-
self.counts["group"]
564-
.map(group_mapping)
565-
.fillna(self.counts["group"])
523+
self.counts["group"].map(group_mapping).fillna(self.counts["group"])
566524
)
567525

568526
# Update the sample_metadata using the sample_group_col
@@ -597,15 +555,11 @@ def update_groups(
597555
)
598556

599557
# Create a mapping from filename to new group
600-
filename_to_group = metadata.set_index("filename")[
601-
merge_column
602-
].to_dict()
558+
filename_to_group = metadata.set_index("filename")[merge_column].to_dict()
603559

604560
# Update the 'group' column in counts
605561
self.counts["group"] = (
606-
self.counts["filename"]
607-
.map(filename_to_group)
608-
.fillna(self.counts["group"])
562+
self.counts["filename"].map(filename_to_group).fillna(self.counts["group"])
609563
)
610564

611565
# Update the sample_metadata
@@ -639,11 +593,7 @@ def generate_RDDflows(
639593
- processes: DataFrame with unique nodes across the levels.
640594
"""
641595
# Use provided max_hierarchy_level or default to the instance's levels
642-
max_level = (
643-
max_hierarchy_level
644-
if max_hierarchy_level is not None
645-
else self.levels
646-
)
596+
max_level = max_hierarchy_level if max_hierarchy_level is not None else self.levels
647597

648598
# Filter counts by filename if a filter is specified
649599
if filename_filter:
@@ -704,16 +654,12 @@ def generate_RDDflows(
704654
flows_df = pd.concat(flows, ignore_index=True)
705655

706656
# Build processes from unique nodes in flows
707-
all_nodes = (
708-
pd.concat([flows_df["source"], flows_df["target"]])
709-
.dropna()
710-
.unique()
711-
)
657+
all_nodes = pd.concat([flows_df["source"], flows_df["target"]]).dropna().unique()
712658
processes_df = pd.DataFrame(
713659
{
714660
"id": all_nodes,
715661
"level": [int(node.split("_")[-1]) for node in all_nodes],
716662
}
717663
).set_index("id")
718664

719-
return flows_df, processes_df
665+
return flows_df, processes_df

0 commit comments

Comments (0)