55from tqdm import tqdm
66import argparse
77import torch
8+ import pandas as pd
910from MolecularDiffusion .utils import create_pyg_graph , correct_edges
1011from MolecularDiffusion .utils .geom_utils import read_xyz_file
@@ -159,7 +160,9 @@ def get_xtb_optimized_xyz(
159160 level : str = "gfn1" ,
160161 timeout : int = 240 ,
161162 scale_factor : float = 1.3 ,
162- optimize_all : bool = True
163+ optimize_all : bool = True ,
164+ csv_path : str = None ,
165+ filter_column : str = None
163166) -> list [str ]:
164167 """
165168 Optimizes all XYZ files in a given input directory using xTB and saves them
@@ -180,6 +183,8 @@ def get_xtb_optimized_xyz(
180183 timeout (int, optional): The maximum time in seconds to wait for each xTB process. Defaults to 240.
181184 scale_factor (float, optional): The scaling factor for covalent radii in edge correction. Defaults to 1.3.
182185 optimize_all (bool, optional): If True, optimizes all files regardless of existing optimized versions.
186+ csv_path (str, optional): Path to a CSV file to filter which XYZ files to optimize.
187+ filter_column (str, optional): The column name in the CSV to filter by (values must be 1).
183188
184189 Returns:
185190 list[str]: A list of paths to the successfully optimized XYZ files.
@@ -189,7 +194,49 @@ def get_xtb_optimized_xyz(
189194
190195 os .makedirs (output_directory , exist_ok = True )
191196
192- xyz_files = glob .glob (os .path .join (input_directory , "*.xyz" ))
197+ xyz_files = []
198+ if csv_path :
199+ if not os .path .exists (csv_path ):
200+ raise FileNotFoundError (f"CSV file not found: { csv_path } " )
201+
202+ df = pd .read_csv (csv_path )
203+
204+ fname_col = None
205+ for col in ["xyz_file" , "filename" , "filepath" ]:
206+ if col in df .columns :
207+ fname_col = col
208+ break
209+
210+ if fname_col is None :
211+ raise ValueError ("CSV must contain 'xyz_file', 'filename', or 'filepath' column." )
212+
213+ if filter_column :
214+ if filter_column not in df .columns :
215+ raise ValueError (f"Filter column '{ filter_column } ' not found in CSV." )
216+ # Filter rows where the value is 1 (as integer or string)
217+ filtered_df = df [df [filter_column ].isin (['1' , '1.0' , True , 1 ])]
218+ else :
219+ filtered_df = df
220+
221+ for _ , row in filtered_df .iterrows ():
222+ fname = str (row [fname_col ])
223+ # Handle potential missing extension if it's just a name
224+ if not fname .lower ().endswith ('.xyz' ):
225+ fname += '.xyz'
226+
227+ if os .path .isabs (fname ):
228+ full_path = fname
229+ else :
230+ full_path = os .path .join (input_directory , fname )
231+
232+ if os .path .exists (full_path ):
233+ xyz_files .append (full_path )
234+ else :
235+ print (f"Warning: File from CSV not found: { full_path } " )
236+
237+ else :
238+ xyz_files = glob .glob (os .path .join (input_directory , "*.xyz" ))
239+
193240 optimized_files = []
194241
195242 for xyz_file in tqdm (xyz_files , desc = "Optimizing XYZ files" , total = len (xyz_files )):
@@ -260,6 +307,18 @@ def get_xtb_optimized_xyz(
260307 default = 1.3 ,
261308 help = "Scaling factor for covalent radii in edge correction. Defaults to 1.3."
262309 )
310+ parser .add_argument (
311+ "--csv_path" ,
312+ type = str ,
313+ default = None ,
314+ help = "Path to CSV file for filtering which files to optimize."
315+ )
316+ parser .add_argument (
317+ "--filter_column" ,
318+ type = str ,
319+ default = None ,
320+ help = "Column name in CSV to filter by (values must be 1 to process)."
321+ )
263322
264323 args = parser .parse_args ()
265324
@@ -272,7 +331,9 @@ def get_xtb_optimized_xyz(
272331 charge = args .charge ,
273332 level = args .level ,
274333 timeout = args .timeout ,
275- scale_factor = args .scale_factor
334+ scale_factor = args .scale_factor ,
335+ csv_path = args .csv_path ,
336+ filter_column = args .filter_column
276337 )
277338
278339 print (f"Successfully optimized { len (optimized_files )} XYZ files and saved them in '{ output_dir } '." )