# Select the torch device: CUDA when torch reports it as available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Command-line interface.
# RawDescriptionHelpFormatter is required because the epilog contains explicit
# newlines: argparse's default formatter re-wraps the epilog into a single
# paragraph, which would run the usage examples together.
parser = argparse.ArgumentParser(
    description="Encode PDB files into PyTorch geometric graphs with optional FoldX data integration.",
    epilog=("Example usage:\n"
            " python encode_pdbs.py /path/to/pdbs output.h5\n"
            " python encode_pdbs.py '/path/**/*.pdb' output.h5 /path/to/foldx"),
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument('input_path', type=str, help='Input directory with pdbs or glob pattern (e.g., /path/to/pdbs or "/path/**/*.pdb")')
parser.add_argument('output_h5', type=str, help='Output file with pytorch geometric graphs of pdbs')
parser.add_argument('foldxdir', type=str, nargs='?', default=None, help='foldx directory with foldx output for all pdbs')
parser.add_argument('--distance', type=float, default=15, help='Distance threshold for contact map (default: 15)')
# `store_true` combined with `default=True` means passing --add-prody never
# changes the value; --no-add-prody (same dest) is the only way to turn it off.
parser.add_argument('--add-prody', action='store_true', default=True, help='Add ProDy features (default: True)')
parser.add_argument('--no-add-prody', dest='add_prody', action='store_false', help='Disable ProDy features')
parser.add_argument('--verbose', action='store_true', default=False, help='Verbose output')
parser.add_argument('--multiprocessing', action='store_true', default=False, help='Use multiprocessing for parallel processing')
parser.add_argument('--ncpu', type=int, default=25, help='Number of CPUs for multiprocessing (default: 25)')
parser.add_argument('--nstructs', type=int, default=None, help='Number of structures to use (random subsample if specified)')

args = parser.parse_args()

# A negative count would silently drop files from the END of the shuffled list
# (files[:-k]) instead of selecting k of them, so reject it up front.
if args.nstructs is not None and args.nstructs < 0:
    parser.error('--nstructs must be a non-negative integer')

# Resolve the input: an existing directory is searched (non-recursively) for
# *.pdb files; anything else is treated as a glob pattern, with `**` allowed to
# match nested directories.
input_source = args.input_path
if os.path.isdir(input_source):
    files = glob.glob(os.path.join(input_source, '*.pdb'))
else:
    files = glob.glob(input_source, recursive=True)

print(f"Found {len(files)} PDB files from {input_source}")

# glob returns paths in arbitrary order; sorting first makes the shuffle (and
# therefore the subsample below) reproducible whenever numpy's RNG is seeded.
files.sort()

# Shuffle the data for randomization (in place).
np.random.shuffle(files)

# Subsample if nstructs is specified: keep the first N of the shuffled list.
if args.nstructs is not None:
    files = files[:args.nstructs]

output_h5 = args.output_h5
foldx = args.foldxdir
0 commit comments