@@ -197,7 +197,7 @@ def _write_output_file(
197197
198198def predict_ccs (
199199 psm_list_pred : PSMList ,
200- psm_list_cal : Optional [PSMList ] = None ,
200+ psm_list_cal : Optional [Union [ PSMList , pd . DataFrame ] ] = None ,
201201 file_reference : Optional [Union [str , Path ]] = None ,
202202 output_file : Optional [Union [str , Path ]] = None ,
203203 model_name : str = "tims" ,
@@ -222,8 +222,10 @@ def predict_ccs(
222222 psm_list_pred : PSMList
223223 PSM list containing peptides for CCS prediction. Each PSM should contain
224224 a valid peptidoform with sequence and modifications.
225- psm_list_cal : PSMList, optional
226- PSM list for calibration with observed CCS values in metadata.
225+ psm_list_cal : PSMList or pd.DataFrame, optional
226+ PSM list or DataFrame for calibration with observed CCS values.
227+ If PSMList: CCS values should be in metadata with key "CCS".
228+ If DataFrame: should have "ccs_observed" column.
227229 Required for calibration. Default is None (no calibration).
228230 file_reference : str or Path, optional
229231 Path to reference dataset file for calibration. Default uses built-in
@@ -356,10 +358,27 @@ def predict_ccs(
356358 if psm_list_cal is not None :
357359 try :
358360 LOGGER .info ("Applying calibration..." )
359- psm_list_cal_df = psm_list_cal .to_dataframe ()
360- psm_list_cal_df ["ccs_observed" ] = psm_list_cal_df ["metadata" ].apply (
361- lambda x : float (x .get ("CCS" )) if x and "CCS" in x else None
362- )
361+
362+ # Handle both PSMList and DataFrame input
363+ if isinstance (psm_list_cal , pd .DataFrame ):
364+ # Input is already a DataFrame with ccs_observed column
365+ psm_list_cal_df = psm_list_cal .copy ()
366+ if "ccs_observed" not in psm_list_cal_df .columns :
367+ raise IM2DeepError (
368+ "DataFrame calibration data must contain 'ccs_observed' column"
369+ )
370+ else :
371+ # Input is PSMList, extract CCS from metadata
372+ ccs_values = []
373+ for psm in psm_list_cal :
374+ if psm .metadata and "CCS" in psm .metadata :
375+ ccs_values .append (float (psm .metadata ["CCS" ]))
376+ else :
377+ ccs_values .append (None )
378+
379+ # Convert to DataFrame and add CCS values
380+ psm_list_cal_df = psm_list_cal .to_dataframe ()
381+ psm_list_cal_df ["ccs_observed" ] = ccs_values
363382
364383 # Filter out entries without CCS values
365384 psm_list_cal_df = psm_list_cal_df [psm_list_cal_df ["ccs_observed" ].notnull ()]
0 commit comments