@@ -68,10 +68,9 @@ def __init__(
6868 task_id : Optional [str ] = None ,
6969 gnps_2 : bool = True ,
7070 ) -> None :
71-
7271 """
7372 Initialize an RDDCounts instance and prepare all internal data structures used to compute RDD counts across ontology levels.
74-
73+
7574 Parameters:
7675 sample_types (str): Identifier or path used to load reference/sample ontology assignments.
7776 gnps_network_path (Optional[str]): Path to a GNPS network TSV; provide exactly one of this or `task_id`.
@@ -85,17 +84,15 @@ def __init__(
8584 blank_identifier (str): Reference label used as the blank/water baseline to subtract from counts (default "water").
8685 task_id (Optional[str]): GNPS task identifier to fetch network data; provide exactly one of this or `gnps_network_path`.
8786 gnps_2 (bool): Whether to fetch GNPS v2-formatted task data when `task_id` is used.
88-
87+
8988 Side effects:
9089 - Validates that exactly one of `task_id` or `gnps_network_path` is provided and raises ValueError otherwise.
9190 - Loads GNPS network, reference metadata, sample metadata, and sample type mappings.
9291 - Determines and validates ontology `levels` (auto-detects when `levels` is `None`).
9392 - Builds `ontology_table`, normalizes the GNPS network, and computes `file_level_counts` and aggregated `counts` across all taxonomy levels as instance attributes.
9493 """
9594 if (task_id is None ) == (gnps_network_path is None ):
96- raise ValueError (
97- "Provide exactly one of task_id or gnps_network_path."
98- )
95+ raise ValueError ("Provide exactly one of task_id or gnps_network_path." )
9996 if task_id is None :
10097 self .raw_gnps_network = pd .read_csv (gnps_network_path , sep = "\t " )
10198 else :
@@ -107,34 +104,32 @@ def __init__(
107104 self .sample_group_col = sample_group_col
108105 self .blank_identifier = blank_identifier
109106
110- self .reference_metadata = _load_RDD_metadata (
111- external_reference_metadata
112- )
107+ self .reference_metadata = _load_RDD_metadata (external_reference_metadata )
113108 self .sample_metadata = get_sample_metadata (
114109 self .raw_gnps_network ,
115110 sample_groups ,
116111 external_sample_metadata ,
117112 filename_col = "filename" ,
118113 )
119- self .sample_types_df , self .ontology_columns_renamed = (
120- _load_sample_types (
121- self .reference_metadata ,
122- sample_types ,
123- ontology_columns = ontology_columns ,
124- )
114+ self .sample_types_df , self .ontology_columns_renamed = _load_sample_types (
115+ self .reference_metadata ,
116+ sample_types ,
117+ ontology_columns = ontology_columns ,
125118 )
126119
127120 # Auto-determine levels if not specified
128121 if levels is None :
129122 self .levels = self ._determine_ontology_levels ()
130123 else :
131124 self .levels = levels
132- # Validate that specified levels don't exceed available columns
125+ # Validate that specified levels are valid
126+ if self .levels < 0 :
127+ raise ValueError (f"levels must be non-negative, got { self .levels } ." )
133128 available_levels = self ._determine_ontology_levels ()
134129 if self .levels > available_levels :
135130 raise ValueError (
136131 f"levels ({ self .levels } ) exceeds available ontology columns "
137- f"({ available_levels } )."
132+ f"({ available_levels } ). Use levels=0 for file-level counts only. "
138133 )
139134
140135 self .ontology_table = (
@@ -155,36 +150,37 @@ def __init__(
def _determine_ontology_levels(self) -> int:
    """
    Count the ontology levels available for RDD aggregation.

    Prefers an explicit list of renamed ontology columns when one was
    supplied at construction time; otherwise falls back to counting the
    ``sample_type_group{n}`` columns present in ``reference_metadata``.

    Returns:
        int: Number of available ontology levels.
    """
    if self.ontology_columns_renamed:
        return len(self.ontology_columns_renamed)

    # No custom columns supplied: tally the default sample_type_group<N>
    # columns found in the reference metadata.
    group_pattern = re.compile(r"sample_type_group\d+$")
    return sum(
        1
        for column in self.reference_metadata.columns
        if group_pattern.match(column)
    )
174170
175171 def _get_ontology_column_for_level (self , level : int ) -> str :
176172 """
177173 Return the ontology column name for the specified ontology level.
178-
174+
179175 If custom ontology columns were provided (stored in self.ontology_columns_renamed), the method returns
180176 the corresponding renamed column; otherwise it returns the default name "sample_type_group{level}".
181-
177+
182178 Parameters:
183179 level (int): 1-based ontology level to resolve.
184-
180+
185181 Returns:
186182 str: Column name corresponding to the given level.
187-
183+
188184 Raises:
189185 ValueError: If `self.ontology_columns_renamed` is set and `level` is outside the valid range
190186 (1..len(self.ontology_columns_renamed)).
@@ -242,9 +238,7 @@ def file_counts(
242238 sample_group_col = self .sample_group_col ,
243239 reference_name_col = reference_name_col ,
244240 )
245- sample_clusters .drop_duplicates (
246- subset = ["filename" , "cluster_index" ], inplace = True
247- )
241+ sample_clusters .drop_duplicates (subset = ["filename" , "cluster_index" ], inplace = True )
248242 reference_clusters .drop_duplicates (
249243 subset = ["filename" , reference_name_col , "cluster_index" ],
250244 inplace = True ,
@@ -280,9 +274,7 @@ def file_counts(
280274 filename_to_group = self .sample_metadata .set_index ("filename" )[
281275 self .sample_group_col
282276 ].to_dict ()
283- cluster_count_long ["group" ] = cluster_count_long ["filename" ].map (
284- filename_to_group
285- )
277+ cluster_count_long ["group" ] = cluster_count_long ["filename" ].map (filename_to_group )
286278 return cluster_count_long
287279
288280 def create_RDD_counts_all_levels (self ) -> pd .DataFrame :
@@ -309,9 +301,7 @@ def create_RDD_counts_all_levels(self) -> pd.DataFrame:
309301 RDD_counts_file_level
310302 ] # Initialize a list for storing data at all levels
311303 if "reference_type" not in RDD_counts_file_level .columns :
312- raise ValueError (
313- "Expected 'reference_type' column in file-level counts."
314- )
304+ raise ValueError ("Expected 'reference_type' column in file-level counts." )
315305 RDD_counts_file_level_sample_types = RDD_counts_file_level .merge (
316306 self .sample_types_df ,
317307 left_on = "reference_type" ,
@@ -327,9 +317,7 @@ def create_RDD_counts_all_levels(self) -> pd.DataFrame:
327317 ontology_col = self ._get_ontology_column_for_level (level )
328318
329319 RDD_counts_level = (
330- RDD_counts_file_level_sample_types .groupby (
331- ["filename" , ontology_col ]
332- )["count" ]
320+ RDD_counts_file_level_sample_types .groupby (["filename" , ontology_col ])["count" ]
333321 .sum ()
334322 .reset_index ()
335323 )
@@ -346,21 +334,15 @@ def create_RDD_counts_all_levels(self) -> pd.DataFrame:
346334 columns_to_modify = wide_format_counts .columns .difference (
347335 ["filename" , self .blank_identifier ]
348336 )
349- wide_format_counts .loc [
337+ wide_format_counts .loc [:, columns_to_modify ] = wide_format_counts . loc [
350338 :, columns_to_modify
351- ] = wide_format_counts .loc [:, columns_to_modify ].where (
352- wide_format_counts .loc [:, columns_to_modify ].gt (
353- water_counts , axis = 0
354- ),
339+ ].where (
340+ wide_format_counts .loc [:, columns_to_modify ].gt (water_counts , axis = 0 ),
355341 0 ,
356342 )
357- wide_format_counts = wide_format_counts .drop (
358- columns = [self .blank_identifier ]
359- )
343+ wide_format_counts = wide_format_counts .drop (columns = [self .blank_identifier ])
360344
361- wide_format_counts = wide_format_counts .loc [
362- :, (wide_format_counts != 0 ).any (axis = 0 )
363- ]
345+ wide_format_counts = wide_format_counts .loc [:, (wide_format_counts != 0 ).any (axis = 0 )]
364346 if wide_format_counts .empty :
365347 continue # Skip this level
366348 RDD_counts_level = wide_format_counts .melt (
@@ -372,19 +354,13 @@ def create_RDD_counts_all_levels(self) -> pd.DataFrame:
372354
373355 RDD_counts_all_levels .append (RDD_counts_level )
374356
375- RDD_counts_all_levels = pd .concat (
376- RDD_counts_all_levels , ignore_index = True
377- )
357+ RDD_counts_all_levels = pd .concat (RDD_counts_all_levels , ignore_index = True )
378358
379359 # Map group information from the sample_metadata to the final DataFrame
380- RDD_counts_all_levels ["group" ] = RDD_counts_all_levels ["filename" ].map (
381- sample_metadata_map
382- )
360+ RDD_counts_all_levels ["group" ] = RDD_counts_all_levels ["filename" ].map (sample_metadata_map )
383361
384362 # Cast 'count' as an integer
385- RDD_counts_all_levels ["count" ] = RDD_counts_all_levels ["count" ].astype (
386- int
387- )
363+ RDD_counts_all_levels ["count" ] = RDD_counts_all_levels ["count" ].astype (int )
388364
389365 return RDD_counts_all_levels
390366
@@ -440,19 +416,13 @@ def filter_counts(
440416 )
441417
442418 # Resolve level column names via helper to support both default and custom columns
443- upper_ontology_col = self ._get_ontology_column_for_level (
444- upper_level
445- )
446- lower_ontology_col = self ._get_ontology_column_for_level (
447- lower_level
448- )
419+ upper_ontology_col = self ._get_ontology_column_for_level (upper_level )
420+ lower_ontology_col = self ._get_ontology_column_for_level (lower_level )
449421
450422 # Matching lower-level reference types
451423 reference_types = (
452424 self .ontology_table [
453- self .ontology_table [upper_ontology_col ].isin (
454- upper_level_reference_types
455- )
425+ self .ontology_table [upper_ontology_col ].isin (upper_level_reference_types )
456426 ][lower_ontology_col ]
457427 .dropna ()
458428 .unique ()
@@ -469,9 +439,7 @@ def filter_counts(
469439 if sample_names :
470440 if isinstance (sample_names , str ):
471441 sample_names = [sample_names ]
472- filtered_df = filtered_df [
473- filtered_df ["filename" ].isin (sample_names )
474- ]
442+ filtered_df = filtered_df [filtered_df ["filename" ].isin (sample_names )]
475443
476444 # Filter by group(s)
477445 if group is not None :
@@ -481,24 +449,18 @@ def filter_counts(
481449
482450 # Filter by explicitly provided reference_types first (before top_n)
483451 if reference_types is not None :
484- filtered_df = filtered_df [
485- filtered_df ["reference_type" ].isin (reference_types )
486- ]
452+ filtered_df = filtered_df [filtered_df ["reference_type" ].isin (reference_types )]
487453
488454 # Select top N reference types if requested (only if reference_types not explicitly provided)
489455 if top_n is not None and reference_types is None :
490456 if top_n_method == "per_sample" :
491457 top_df = (
492- filtered_df .sort_values (
493- ["filename" , "count" ], ascending = [True , False ]
494- )
458+ filtered_df .sort_values (["filename" , "count" ], ascending = [True , False ])
495459 .groupby ("filename" )
496460 .head (top_n )
497461 .reset_index (drop = True )
498462 )
499- top_reference_types = (
500- top_df ["reference_type" ].dropna ().unique ().tolist ()
501- )
463+ top_reference_types = top_df ["reference_type" ].dropna ().unique ().tolist ()
502464
503465 elif top_n_method == "total" :
504466 top_reference_types = (
@@ -521,9 +483,7 @@ def filter_counts(
521483 "Invalid top_n_method. Choose from 'per_sample', 'total', or 'average'."
522484 )
523485
524- filtered_df = filtered_df [
525- filtered_df ["reference_type" ].isin (top_reference_types )
526- ]
486+ filtered_df = filtered_df [filtered_df ["reference_type" ].isin (top_reference_types )]
527487
528488 return filtered_df
529489
@@ -560,9 +520,7 @@ def update_groups(
560520
561521 # Update the 'group' column in counts using the mapping
562522 self .counts ["group" ] = (
563- self .counts ["group" ]
564- .map (group_mapping )
565- .fillna (self .counts ["group" ])
523+ self .counts ["group" ].map (group_mapping ).fillna (self .counts ["group" ])
566524 )
567525
568526 # Update the sample_metadata using the sample_group_col
@@ -597,15 +555,11 @@ def update_groups(
597555 )
598556
599557 # Create a mapping from filename to new group
600- filename_to_group = metadata .set_index ("filename" )[
601- merge_column
602- ].to_dict ()
558+ filename_to_group = metadata .set_index ("filename" )[merge_column ].to_dict ()
603559
604560 # Update the 'group' column in counts
605561 self .counts ["group" ] = (
606- self .counts ["filename" ]
607- .map (filename_to_group )
608- .fillna (self .counts ["group" ])
562+ self .counts ["filename" ].map (filename_to_group ).fillna (self .counts ["group" ])
609563 )
610564
611565 # Update the sample_metadata
@@ -639,11 +593,7 @@ def generate_RDDflows(
639593 - processes: DataFrame with unique nodes across the levels.
640594 """
641595 # Use provided max_hierarchy_level or default to the instance's levels
642- max_level = (
643- max_hierarchy_level
644- if max_hierarchy_level is not None
645- else self .levels
646- )
596+ max_level = max_hierarchy_level if max_hierarchy_level is not None else self .levels
647597
648598 # Filter counts by filename if a filter is specified
649599 if filename_filter :
@@ -704,16 +654,12 @@ def generate_RDDflows(
704654 flows_df = pd .concat (flows , ignore_index = True )
705655
706656 # Build processes from unique nodes in flows
707- all_nodes = (
708- pd .concat ([flows_df ["source" ], flows_df ["target" ]])
709- .dropna ()
710- .unique ()
711- )
657+ all_nodes = pd .concat ([flows_df ["source" ], flows_df ["target" ]]).dropna ().unique ()
712658 processes_df = pd .DataFrame (
713659 {
714660 "id" : all_nodes ,
715661 "level" : [int (node .split ("_" )[- 1 ]) for node in all_nodes ],
716662 }
717663 ).set_index ("id" )
718664
719- return flows_df , processes_df
665+ return flows_df , processes_df
0 commit comments