Skip to content

Commit ac4ef84

Browse files
committed
Enough to make it run
1 parent 443164c commit ac4ef84

File tree

1 file changed

+83
-47
lines changed

1 file changed

+83
-47
lines changed

src/cryoemservices/services/extract_subtomo.py

Lines changed: 83 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import ast
34
from pathlib import Path
45

56
import mrcfile
@@ -22,9 +23,8 @@ class ExtractSubTomoParameters(BaseModel):
2223
tilt_alignment_file: str = Field(..., min_length=1)
2324
newstack_file: str = Field(..., min_length=1)
2425
output_star: str = Field(..., min_length=1)
25-
scaled_tomogram_shape: list[int]
26+
scaled_tomogram_shape: list[int] | str
2627
pixel_size: float
27-
refined_tilt_axis: float
2828
particle_diameter: float = 0
2929
boxsize: int = 256
3030
small_boxsize: int = 64
@@ -39,9 +39,15 @@ class ExtractSubTomoParameters(BaseModel):
3939
@field_validator("scaled_tomogram_shape")
4040
@classmethod
4141
def check_shape_is_3d(cls, v):
42-
if len(v) != 3:
42+
if not len(v):
43+
raise ValueError("Tomogram shape not given")
44+
if type(v) is str:
45+
shape_list = ast.literal_eval(v)
46+
else:
47+
shape_list = v
48+
if len(shape_list) != 3:
4349
raise ValueError("Tomogram shape must be 3D")
44-
return v
50+
return shape_list
4551

4652

4753
class ExtractSubTomo(CommonService):
@@ -105,7 +111,7 @@ def extract_subtomo(self, rw, header: dict, message: dict):
105111
self.log.info(
106112
f"Inputs: {extract_subtomo_params.tilt_alignment_file}, "
107113
f"{extract_subtomo_params.cbox_3d_file} "
108-
f"Output: {extract_subtomo_params.output_file}"
114+
f"Output: {extract_subtomo_params.output_star}"
109115
)
110116

111117
# Update the relion options and get box sizes
@@ -114,8 +120,8 @@ def extract_subtomo(self, rw, header: dict, message: dict):
114120
)
115121

116122
# Make sure the output directory exists
117-
if not Path(extract_subtomo_params.output_file).parent.exists():
118-
Path(extract_subtomo_params.output_file).parent.mkdir(parents=True)
123+
if not Path(extract_subtomo_params.output_star).parent.exists():
124+
Path(extract_subtomo_params.output_star).parent.mkdir(parents=True)
119125

120126
# If no background radius set diameter as 75% of box
121127
if extract_subtomo_params.bg_radius == -1:
@@ -140,14 +146,16 @@ def extract_subtomo(self, rw, header: dict, message: dict):
140146
# Get the shifts between tilts
141147
shift_data = np.genfromtxt(extract_subtomo_params.tilt_alignment_file)
142148
# tilt_ids = shift_data[:, 0].astype(int)
149+
refined_tilt_axis = float(shift_data[0, 1])
143150
x_shifts = shift_data[:, 3].astype(float)
144151
y_shifts = shift_data[:, 4].astype(float)
145152
tilt_count = len(x_shifts)
146153

147-
# Extraction
154+
# Rotation around the tilt axis is about (0, height/2)
155+
# Or possibly not, sometimes seems to be (width/2, height/2), needs exploration
148156
centre_x = 0
149-
centre_y = extract_subtomo_params.scaled_tomogram_shape[1] / 2
150-
tilt_axis_radians = extract_subtomo_params.refined_tilt_axis * np.pi / 180
157+
centre_y = float(extract_subtomo_params.scaled_tomogram_shape[1]) / 2
158+
tilt_axis_radians = (90 - refined_tilt_axis) * np.pi / 180
151159

152160
x_coords_in_tilts = centre_x + (
153161
(particles_x - centre_x) * np.cos(tilt_axis_radians)
@@ -157,6 +165,8 @@ def extract_subtomo(self, rw, header: dict, message: dict):
157165
(particles_x - centre_x) * np.sin(tilt_axis_radians)
158166
+ (particles_y - centre_y) * np.cos(tilt_axis_radians)
159167
)
168+
x_coords_in_tilts *= extract_subtomo_params.tomogram_binning
169+
y_coords_in_tilts *= extract_subtomo_params.tomogram_binning
160170

161171
# Downscaling dimensions
162172
extract_subtomo_params.relion_options.pixel_size_downscaled = (
@@ -171,15 +181,28 @@ def extract_subtomo(self, rw, header: dict, message: dict):
171181
box_len = extract_subtomo_params.relion_options.small_boxsize
172182
pixel_size = extract_subtomo_params.relion_options.pixel_size_downscaled
173183

184+
# Distance of each pixel from the centre for background normalization
185+
grid_indexes = np.meshgrid(
186+
np.arange(2 * scaled_extract_width),
187+
np.arange(2 * scaled_extract_width),
188+
)
189+
distance_from_centre = np.sqrt(
190+
(grid_indexes[0] - scaled_extract_width + 0.5) ** 2
191+
+ (grid_indexes[1] - scaled_extract_width + 0.5) ** 2
192+
)
193+
174194
# Read in tilt images
195+
self.log.info("Reading tilt images")
175196
tilt_images = []
176197
with open(extract_subtomo_params.newstack_file) as ns_file:
177198
while True:
178199
line = ns_file.readline()
179200
if not line:
180201
break
181202
elif line.startswith("/"):
182-
tilt_images.append(line)
203+
tilt_name = line.strip()
204+
with mrcfile.open(tilt_name) as mrc:
205+
tilt_images.append(mrc.data)
183206

184207
for particle in range(len(particles_x)):
185208
output_mrc_stack = np.array([])
@@ -190,33 +213,52 @@ def extract_subtomo(self, rw, header: dict, message: dict):
190213
y_top_pad = 0
191214
y_bot_pad = 0
192215

193-
x_left = (
194-
x_coords_in_tilts[particle] - extract_width - x_shifts[particle]
216+
x_left = round(
217+
x_coords_in_tilts[particle] - extract_width - x_shifts[tilt]
195218
)
196219
if x_left < 0:
197220
x_left_pad = -x_left
198221
x_left = 0
199-
x_right = (
200-
x_coords_in_tilts[particle] + extract_width - x_shifts[particle]
222+
x_right = round(
223+
x_coords_in_tilts[particle] + extract_width - x_shifts[tilt]
201224
)
202-
if x_right >= extract_subtomo_params.scaled_tomogram_shape[1]:
225+
if (
226+
x_right
227+
>= extract_subtomo_params.scaled_tomogram_shape[0]
228+
* extract_subtomo_params.tomogram_binning
229+
):
203230
x_right_pad = (
204-
x_right - extract_subtomo_params.scaled_tomogram_shape[1]
231+
x_right - extract_subtomo_params.scaled_tomogram_shape[0]
232+
)
233+
x_right = (
234+
extract_subtomo_params.scaled_tomogram_shape[0]
235+
* extract_subtomo_params.tomogram_binning
205236
)
206-
x_right = extract_subtomo_params.scaled_tomogram_shape[1]
207-
y_top = y_coords_in_tilts[particle] - extract_width - y_shifts[particle]
237+
y_top = round(
238+
y_coords_in_tilts[particle] - extract_width - y_shifts[tilt]
239+
)
208240
if y_top < 0:
209241
y_top_pad = -y_top
210242
y_top = 0
211-
y_bot = y_coords_in_tilts[particle] + extract_width - y_shifts[particle]
212-
if y_bot >= extract_subtomo_params.scaled_tomogram_shape[0]:
213-
y_bot_pad = y_bot - extract_subtomo_params.scaled_tomogram_shape[0]
214-
y_bot = extract_subtomo_params.scaled_tomogram_shape[0]
243+
y_bot = round(
244+
y_coords_in_tilts[particle] + extract_width - y_shifts[tilt]
245+
)
246+
if (
247+
y_bot
248+
>= extract_subtomo_params.scaled_tomogram_shape[1]
249+
* extract_subtomo_params.tomogram_binning
250+
):
251+
y_bot_pad = y_bot - extract_subtomo_params.scaled_tomogram_shape[1]
252+
y_bot = (
253+
extract_subtomo_params.scaled_tomogram_shape[1]
254+
* extract_subtomo_params.tomogram_binning
255+
)
215256

216-
with mrcfile.open(tilt_images[tilt]) as mrc:
217-
input_micrograph_image = mrc.data
257+
if y_bot <= y_top or x_left >= x_right:
258+
self.log.warning(f"Invalid {tilt} for particle {particle}")
259+
continue
218260

219-
particle_subimage = input_micrograph_image[y_top:y_bot, x_left:x_right]
261+
particle_subimage = tilt_images[tilt][y_top:y_bot, x_left:x_right]
220262
particle_subimage = np.pad(
221263
particle_subimage,
222264
((y_bot_pad, y_top_pad), (x_left_pad, x_right_pad)),
@@ -248,22 +290,12 @@ def extract_subtomo(self, rw, header: dict, message: dict):
248290
)
249291
)
250292

251-
# Distance of each pixel from the centre, compared to background radius
252-
grid_indexes = np.meshgrid(
253-
np.arange(2 * scaled_extract_width),
254-
np.arange(2 * scaled_extract_width),
255-
)
256-
distance_from_centre = np.sqrt(
257-
(grid_indexes[0] - scaled_extract_width + 0.5) ** 2
258-
+ (grid_indexes[1] - scaled_extract_width + 0.5) ** 2
259-
)
293+
# Background normalisation
260294
bg_region = (
261295
distance_from_centre
262296
> np.ones(np.shape(particle_subimage))
263297
* extract_subtomo_params.bg_radius
264298
)
265-
266-
# Background normalisation
267299
bg_mean = np.mean(particle_subimage[bg_region])
268300
bg_std = np.std(particle_subimage[bg_region])
269301
particle_subimage = (particle_subimage - bg_mean) / bg_std
@@ -276,12 +308,16 @@ def extract_subtomo(self, rw, header: dict, message: dict):
276308
else:
277309
output_mrc_stack = np.array([particle_subimage], dtype=np.float32)
278310

311+
if not len(output_mrc_stack):
312+
self.log.warning(f"Could not extract particle {particle}")
313+
continue
314+
279315
# Produce the mrc file of the extracted particles
280316
output_mrc_file = (
281-
Path(extract_subtomo_params.output_file).parent
317+
Path(extract_subtomo_params.output_star).parent
282318
/ f"{particle}_stack2d.mrcs"
283319
)
284-
self.log.info(f"Extracted particle {particle}")
320+
self.log.info(f"Extracted particle {particle} of {len(particles_x)}")
285321
with mrcfile.new(str(output_mrc_file), overwrite=True) as mrc:
286322
mrc.set_data(output_mrc_stack.astype(np.float32))
287323
mrc.header.mx = box_len
@@ -315,41 +351,41 @@ def extract_subtomo(self, rw, header: dict, message: dict):
315351
[
316352
str(
317353
float(particles_x[particle])
318-
- extract_subtomo_params.scaled_tomogram_shape[2]
354+
- float(extract_subtomo_params.scaled_tomogram_shape[2])
319355
/ 2
320356
* extract_subtomo_params.tomogram_binning
321357
),
322358
str(
323359
float(particles_y[particle])
324-
- extract_subtomo_params.scaled_tomogram_shape[1]
360+
- float(extract_subtomo_params.scaled_tomogram_shape[1])
325361
/ 2
326362
* extract_subtomo_params.tomogram_binning
327363
),
328364
str(
329365
float(particles_z[particle])
330-
- extract_subtomo_params.scaled_tomogram_shape[0]
366+
- float(extract_subtomo_params.scaled_tomogram_shape[0])
331367
/ 2
332368
* extract_subtomo_params.tomogram_binning
333369
),
334370
"1",
335-
f"{_get_tilt_name_v5_12(Path(extract_subtomo_params.tomogram))}/{particle}",
336-
f"[{frames}]"
337-
f"{Path(extract_subtomo_params.output_file).parent}/{particle}_stack2d.mrcs",
371+
f"{_get_tilt_name_v5_12(Path(extract_subtomo_params.tilt_alignment_file))}/{particle}",
372+
f"[{frames}]",
373+
f"{Path(extract_subtomo_params.output_star).parent}/{particle}_stack2d.mrcs",
338374
"0.0",
339375
"0.0",
340376
"0.0",
341377
]
342378
)
343379
extracted_parts_doc.write_file(
344-
extract_subtomo_params.output_file, style=cif.Style.Simple
380+
extract_subtomo_params.output_star, style=cif.Style.Simple
345381
)
346382

347383
# Register the extract job with the node creator
348384
self.log.info(f"Sending {self.job_type} to node creator")
349385
node_creator_parameters = {
350386
"job_type": self.job_type,
351387
"input_file": extract_subtomo_params.cbox_3d_file,
352-
"output_file": extract_subtomo_params.output_file,
388+
"output_file": extract_subtomo_params.output_star,
353389
"relion_options": dict(extract_subtomo_params.relion_options),
354390
"command": "",
355391
"stdout": "",

0 commit comments

Comments
 (0)