-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprocess_mtd_traj.py
More file actions
312 lines (271 loc) · 15 KB
/
process_mtd_traj.py
File metadata and controls
312 lines (271 loc) · 15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import argparse
import h5py
import mdtraj as md
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
# Module-level plot styling, executed once at import time: seaborn's
# "whitegrid" style with the "deep" palette, then the "paper" context with
# fonts scaled by a factor of 2.2.
sns.set(style='whitegrid', palette='deep')
sns.set_context(context='paper', font_scale=2.2)
def get_pair_list(topology) -> list[np.ndarray]:
    """
    Select the atoms used for the WC, HG and BP distances of the monitored base pair.

    One selection is run per distance, always in this order:
        WC: N1 of resid 6 -- N3 of resid 16
        HG: N7 of resid 6 -- N3 of resid 16
        BP: N6 of resid 6 -- O4 of resid 16

    The original notes describe the monitored pair as DA7 -- DT37, with the
    neighbouring pairs T6 -- A38 and A8 -- T36.
    NOTE(review): the selections use ``resid 16`` whereas the notes mention
    T37 - confirm the residue numbering against the topology in use.

    :param topology: object providing an MDTraj-style ``select(query)`` method
    :return: list holding the selection result for WC, HG and BP (in that order)
    """
    selection_queries = (
        'name N1 and resid 6 or name N3 and resid 16',  # Watson-Crick
        'name N7 and resid 6 or name N3 and resid 16',  # Hoogsteen
        'name N6 and resid 6 or name O4 and resid 16',  # base pairing
    )
    return [topology.select(query) for query in selection_queries]
def get_bp_distances(trajectory, pairs) -> np.ndarray:
    """
    Compute the base-pair order parameter for every frame of a trajectory.

    The distances of all atom pairs are evaluated with
    ``md.compute_distances(..., periodic=True)``; only the first two columns
    enter the result:

        lambda = arctan2(d[:, 0], d[:, 1])

    With the pairs produced by ``get_pair_list`` column 0 is the WC distance
    (N1-N3) and column 1 the HG distance (N7-N3), i.e.
    lambda = arctan2(d_WC, d_HG). Any further column (e.g. the N6-O4 BP
    distance) is computed but not used.

    Despite the function name, the distances themselves are NOT returned.

    :param trajectory: trajectory accepted by ``md.compute_distances``
    :param pairs: atom pairs, forwarded as ``atom_pairs``; at least two are required
    :return: 1-D float64 array with one order-parameter value per frame
    """
    bpdist = md.compute_distances(trajectory, atom_pairs=pairs, periodic=True)
    # Vectorised over frames; the cast keeps the float64 dtype the previous
    # np.zeros-based implementation returned.
    return np.arctan2(bpdist[:, 0], bpdist[:, 1]).astype(np.float64)
def plot_traj(traj_file, dir_path, out_name, new_path, sys_name) -> None:
    """
    Plot the base-pair order parameter of a trajectory and save it as a PNG.

    :param traj_file: trajectory to plot. Accepted forms:
        - str or Path: HDF5 file name, joined with ``dir_path`` when that is given,
          otherwise used as-is (assumed to be a full path)
        - object with a ``topology`` attribute: used directly as the loaded trajectory
        - object with a ``filename`` attribute (e.g. an open h5py file): reloaded
          from that file name with MDTraj
    :param dir_path: Path or str, directory containing the trajectory file (may be None)
    :param out_name: str, base name for output plot (without extension)
    :param new_path: Path or str, directory to store the plot (default: same as dir_path)
    :param sys_name: str, system name for plot title (no title when empty/None)
    :return: None. The plot is written to ``<new_path>/<out_name>.png``.
    :raises ValueError: if ``traj_file`` has an unsupported type, or if neither
        ``new_path`` nor ``dir_path`` provides an output directory.
    """
    print(f'System name for plot title: {sys_name if sys_name else "None"}')

    if isinstance(traj_file, (str, Path)):
        # File name or path: resolve against dir_path when one is provided.
        traj = md.load_hdf5(Path(dir_path) / traj_file if dir_path else traj_file)
    elif hasattr(traj_file, 'topology'):
        # Already a loaded trajectory object.
        traj = traj_file
    elif hasattr(traj_file, 'filename'):
        # File-like handle (e.g. h5py.File): reload through MDTraj by name.
        traj = md.load_hdf5(traj_file.filename)
    else:
        raise ValueError(f"Cannot handle traj_file of type {type(traj_file)}")

    if not new_path:
        new_path = dir_path
    if not new_path:
        raise ValueError('No output directory: provide new_path or dir_path.')
    out_dir = Path(new_path)  # also accepts plain string directories

    pairs = get_pair_list(traj.topology)
    # Compute the data before creating the figure so that a failure here does
    # not leave an orphaned figure behind.
    order_parameter = get_bp_distances(traj, pairs)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(16, 9), dpi=300)
    try:
        # NOTE(review): the plotted quantity is the arctan2 order parameter
        # returned by get_bp_distances, while the axis label reads
        # "distance (Å)" - confirm the intended label.
        ax.set(xlabel='Frames', ylabel=r'$\mathrm{distance\ (\AA)}$')
        if sys_name:
            print(f'System name found: {sys_name}')
            ax.set_title(sys_name)
        ax.plot(order_parameter, color='b')
        ax.set_ylim((0.38, 1.20))
        ax.margins(x=0, y=0)
        fig.tight_layout()
        fig.savefig(out_dir / f'{out_name}.png')
    finally:
        # Release the figure; otherwise every call (one per chunk when
        # slicing) keeps a 16x9 inch, 300 dpi figure alive.
        plt.close(fig)
def concatenate_slices(slices, dir_path, out_path, out_name, sys_name) -> None:
    """
    Concatenate multiple HDF5 trajectory slices and automatically plot the result.

    The first file defines the layout of the output: every top-level dataset
    is copied together with its attributes; all datasets except ``topology``
    are created resizable along the first axis. For each subsequent file the
    non-``topology`` datasets are appended along the first axis; datasets that
    do not exist in the output are skipped with a warning.

    Any exception raised while concatenating or plotting is caught and
    reported on stdout (best-effort behaviour); a partially written output
    file may remain in that case.

    :param slices: list of str or Path, trajectory files to concatenate. A str
        that is not an existing file is looked up inside ``dir_path``.
    :param dir_path: Path, directory containing the trajectory files
    :param out_path: Path, directory to store output file (default: same as dir_path)
    :param out_name: str, base name for output file (without extension)
    :param sys_name: str, system name for plot title (default: empty)
    :return: None. Writes ``<out_name>.h5`` and ``<out_name>.png`` to the output directory.
    """
    print(f'System name for plot title: {sys_name if sys_name else "None"}')
    if out_path is None:
        out_path = dir_path
    output_file = Path(out_path) / f'{out_name}.h5'

    def copy_dataset(source, name, target_file, growable) -> None:
        """Create ``name`` in ``target_file`` from ``source``, including attributes."""
        if growable:
            # Unlimited first axis so later slices can be appended.
            new_set = target_file.create_dataset(
                name, data=source, chunks=True, maxshape=(None,) + source.shape[1:])
        else:
            new_set = target_file.create_dataset(name, data=source)
        for attr, value in source.attrs.items():
            new_set.attrs[attr] = value

    try:
        with h5py.File(output_file, 'w') as outfile:
            for idx, slice_file in enumerate(slices):
                print(f"Processing file {idx + 1}/{len(slices)}: {slice_file}")
                # Handle both full paths and just filenames
                if isinstance(slice_file, str):
                    slice_path = Path(slice_file)
                    if not slice_path.is_file():
                        slice_path = Path(dir_path) / slice_file
                else:
                    slice_path = slice_file

                with h5py.File(slice_path, "r") as trajectory:
                    for key in trajectory.keys():
                        data = trajectory[key]
                        if idx == 0:
                            # First file: create all datasets
                            copy_dataset(data, key, outfile, growable=(key != 'topology'))
                        elif key == 'topology':
                            # Topology is taken from the first file only.
                            continue
                        elif key not in outfile:
                            print(f" Warning: Dataset '{key}' not found in output file, skipping")
                        else:
                            n_new = data.shape[0]
                            if n_new == 0:
                                # Nothing to append. A negative-index slice
                                # "[-0:]" would address the WHOLE dataset and
                                # make the assignment fail.
                                continue
                            target = outfile[key]
                            n_old = target.shape[0]
                            target.resize(n_old + n_new, axis=0)
                            target[n_old:] = data[...]

        print(f"Successfully concatenated {len(slices)} files into {output_file}")
        # Automatically plot the concatenated trajectory (output file is closed here)
        print("Generating plot for concatenated trajectory...")
        plot_traj(str(output_file), None, out_name, out_path, sys_name)
        print(f"Plot saved as {Path(out_path) / f'{out_name}.png'}")
    except Exception as e:
        print(f'Error during concatenation: {e}')
def slice_trajectory(filename, new_name, dir_path, new_path, start, stop, count, sys_name,
                     chunk_size=10000) -> None:
    """
    Slice a trajectory file into either a single segment (start to stop) or
    ``count`` consecutive chunks of ``chunk_size`` frames, and plot each result.

    Only one of (start/stop) or count should be provided; any other
    combination prints an error message and writes nothing.

    In chunk mode the output files are named ``<new_name>_<NN>.h5`` with a
    zero-padded (width 2) chunk index starting at 0. Errors while writing or
    plotting an individual output are caught and reported on stdout, so the
    remaining chunks are still processed.

    :param filename: str, name of the MDTraj HDF5 trajectory file to slice (without .h5 extension)
    :param new_name: str, base name for output files (without .h5 extension)
    :param dir_path: Path, directory containing the trajectory file
    :param new_path: Path, directory to store output files (default: same as dir_path)
    :param start: int, starting frame number for single slice (inclusive), None if using count
    :param stop: int, ending frame number for single slice (exclusive), None if using count
    :param count: int, number of chunks to create, None if using start/stop
    :param sys_name: str, system name for plot title (default: empty)
    :param chunk_size: int, number of frames per chunk in count mode (default: 10000)
    :return: None
    :raises ValueError: if ``chunk_size`` is not a positive integer
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    if new_path is None:
        new_path = dir_path
    print(f'System name for plot title: {sys_name if sys_name else "None"}')

    def slice_h5traj(whole_traj, sliced_traj, first_frame, last_frame) -> None:
        """Copy every dataset (with attributes); all but 'topology' are cut to the frame range."""
        for key in whole_traj.keys():
            source = whole_traj[key]
            data = source if key == 'topology' else source[first_frame:last_frame]
            new_set = sliced_traj.create_dataset(key, data=data)
            for attr in source.attrs.keys():
                new_set.attrs[attr] = source.attrs[attr]

    def write_and_plot(whole_traj, name, first_frame, last_frame) -> None:
        """Write frames [first_frame, last_frame) to '<name>.h5', then plot that file."""
        output_file = Path(new_path) / f'{name}.h5'
        with h5py.File(output_file, 'w') as sliced_trajectory:
            slice_h5traj(whole_traj, sliced_trajectory, first_frame, last_frame)
        # Plot from the saved (and closed) file
        plot_traj(str(output_file), None, name, new_path, sys_name)

    with h5py.File(Path(dir_path) / f'{filename}.h5', "r") as trajectory:
        if start is not None and stop is not None and not count:
            try:
                write_and_plot(trajectory, new_name, start, stop)
            except Exception as e:
                print(f'Error creating sliced file: {e}')
        elif count and start is None and stop is None:
            print(f'Slicing trajectory into {count} chunks...')
            for cnt in range(count):
                start_frame = cnt * chunk_size
                stop_frame = (cnt + 1) * chunk_size
                chunk_name = f'{new_name}_{str(cnt).zfill(2)}'
                print(f'Processing chunk #{cnt}: frames {start_frame}-{stop_frame}')
                try:
                    write_and_plot(trajectory, chunk_name, start_frame, stop_frame)
                except Exception as e:
                    print(f'Error creating chunk {cnt}: {e}')
        else:
            print("Error: Please provide either (start AND end) OR count, not both or neither.")
def build_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser.

    Shared options (``-ip``, ``-on``, ``-op``, ``-sys``) come before one of the
    sub-commands ``concatenate``, ``slice`` or ``plot``.

    :return: the configured argparse.ArgumentParser
    """
    # The examples live in the epilog with explicit newlines: the raw
    # formatter prints text verbatim, so implicitly concatenated strings
    # without "\n" would show up as a single run-on line.
    examples = (
        "Example usage:\n"
        "  python %(prog)s -ip /path/to/data -on output_name -sys NAME "
        "concatenate -fl file1.h5 file2.h5 file3.h5\n"
        "  python %(prog)s -ip /path/to/data -on sliced_traj -sys NAME "
        "slice trajectory_name --frames 1000 5000\n"
        "  python %(prog)s -ip /path/to/data -on sliced_traj -sys NAME "
        "slice trajectory_name -c 20\n"
        "  python %(prog)s -ip /path/to/data -on plot_output -sys NAME "
        "plot trajectory_name\n"
    )
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Concatenate or slice trajectories in HDF5 format, "
                    "or plot basepair distances to show the transition between "
                    "WCF and HG base pairing.",
        epilog=examples)

    # Shared arguments that all commands need
    parser.add_argument('-ip', '--in_path', type=Path, required=True,
                        help='Directory containing trajectory file(s).')
    parser.add_argument('-on', '--out_name', type=str, required=True,
                        help='Output file name (without extension).')
    parser.add_argument('-op', '--out_path', type=Path, required=False, default=None,
                        help='Directory to store output file(s). Default: same as input.')
    parser.add_argument('-sys', '--system_name', type=str, required=False, default=None,
                        help='System name for plot title. Default: empty.')

    subparsers = parser.add_subparsers(dest='command', help='Available commands', required=True)

    # Concatenate command
    concatenate_parser = subparsers.add_parser('concatenate',
                                               help='Concatenate multiple trajectory files.')
    concatenate_parser.add_argument('-fl', '--file_list', type=str, nargs='+', required=True,
                                    help='List of HDF5 trajectory files to concatenate.')

    # Slice command
    slice_parser = subparsers.add_parser(
        'slice', help='Slice a trajectory file. Provide either start/end frames OR count.')
    slice_parser.add_argument(
        'trajectory_file', type=str,
        help='Name of the MDTraj HDF5 trajectory file to slice (without .h5 extension).')
    # Exactly one of --frames / --count must be given
    slice_group = slice_parser.add_mutually_exclusive_group(required=True)
    slice_group.add_argument('--frames', nargs=2, type=int, metavar=('START', 'END'),
                             help='Start and end frame numbers for single slice.')
    slice_group.add_argument('-c', '--count', type=int,
                             help='Number of 10000-frame chunks to create.')

    # Plot command
    plot_parser = subparsers.add_parser('plot',
                                        help='Plot basepair distances of trajectory.')
    plot_parser.add_argument(
        'trajectory_file', type=str,
        help='Name of the MDTraj HDF5 trajectory file to plot (without .h5 extension).')
    return parser


def resolve_trajectory_path(in_path, name) -> str:
    """
    Turn a file name from ``--file_list`` into a path inside ``in_path``.

    A name that does not end in ``.h5`` and does not exist as given gets the
    ``.h5`` extension appended, so files can be listed with or without it.

    :param in_path: Path or str, directory containing the trajectory files
    :param name: str, file name as typed on the command line
    :return: str, the resolved path
    """
    candidate = Path(in_path) / name
    if not name.endswith('.h5') and not candidate.is_file():
        candidate = Path(in_path) / f'{name}.h5'
    return str(candidate)


def main(argv=None) -> None:
    """
    Parse the command line and run the selected sub-command.

    :param argv: list of str, arguments to parse (default: sys.argv[1:])
    :return: None
    """
    args = build_parser().parse_args(argv)

    if args.command == 'concatenate':
        full_file_list = [resolve_trajectory_path(args.in_path, f) for f in args.file_list]
        concatenate_slices(full_file_list, args.in_path, args.out_path, args.out_name,
                           args.system_name)
    elif args.command == 'slice':
        if args.frames is not None:
            start, end = args.frames
            slice_trajectory(args.trajectory_file, args.out_name, args.in_path, args.out_path,
                             start, end, None, args.system_name)
        else:
            slice_trajectory(args.trajectory_file, args.out_name, args.in_path, args.out_path,
                             None, None, args.count, args.system_name)
    elif args.command == 'plot':
        plot_traj(args.trajectory_file + '.h5', args.in_path, args.out_name, args.out_path,
                  args.system_name)


if __name__ == '__main__':
    main()
# Example usage (shared options come before the sub-command):
# python process_mtd_traj.py -ip /path/to/data -on output_name -sys NAME concatenate -fl file1.h5 file2.h5 file3.h5
# python process_mtd_traj.py -ip /path/to/data -on sliced_traj -sys NAME slice trajectory_name --frames 1000 5000
# python process_mtd_traj.py -ip /path/to/data -on sliced_traj -sys NAME slice trajectory_name -c 20
# python process_mtd_traj.py -ip /path/to/data -on plot_output -sys NAME plot trajectory_name