Skip to content

Commit c16a380

Browse files
Use multi_node_parallel_map for multiprocessing
1 parent f83c7f1 commit c16a380

File tree

1 file changed

+168
-131
lines changed

1 file changed

+168
-131
lines changed

command_line/ssx_index.py

Lines changed: 168 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from xfel.clustering.cluster_groups import unit_cell_info
3131
from cctbx import crystal
3232
from dials.command_line.combine_experiments import CombineWithReference
33+
from dials.util.mp import multi_node_parallel_map
3334

3435
try:
3536
from typing import List
@@ -95,33 +96,170 @@ def _index_one(experiment, refl, params, method_list, expt_no):
9596
return idxr.refined_experiments, idxr.refined_reflections, expt_no
9697

9798

99+
def run_with_disabled_logs(fn, fnargs):
    """Call ``fn(*fnargs)`` with stdout suppressed and noisy DIALS loggers silenced.

    Stdout is redirected to ``os.devnull`` to block printing from rstbx, and a
    fixed set of refinement/indexing loggers is raised to ERROR level for the
    duration of the call.

    Args:
        fn: Callable to invoke.
        fnargs: Sequence of positional arguments unpacked into ``fn``.

    Returns:
        Whatever ``fn(*fnargs)`` returns.

    Raises:
        Whatever ``fn`` raises; stdout is restored and the devnull handle is
        closed even on error (the original version leaked the file handle and
        left stdout redirected if ``fn`` raised).
    """
    import contextlib

    # Loggers that produce per-image chatter we do not want during batch indexing.
    _noisy_loggers = (
        "dials.algorithms.refinement.reflection_processor",
        "dials.algorithms.refinement.refiner",
        "dials.algorithms.refinement.reflection_manager",
        "dials.algorithms.indexing.stills_indexer",
        "dials.algorithms.indexing.nave_parameters",
        "dials.algorithms.indexing.basis_vector_search.real_space_grid_search",
        "dials.algorithms.indexing.basis_vector_search.combinations",
        "dials.algorithms.indexing.indexer",
    )
    devnull = open(os.devnull, "w")  # block printing from rstbx
    sys.stdout = devnull
    try:
        # ExitStack flattens the previous 8-level nested `with` pyramid;
        # the loggers are independent, so entry order does not matter.
        with contextlib.ExitStack() as stack:
            for name in _noisy_loggers:
                stack.enter_context(
                    LoggingContext(logging.getLogger(name), level=logging.ERROR)
                )
            return fn(*fnargs)
    finally:
        sys.stdout = sys.__stdout__  # restore printing
        devnull.close()
124+
125+
126+
class Processor(object):
    # Wrap some functions into a class to allow multiprocessing with
    # multi_node_parallel_map

    def __init__(self, experiments, reflections, params):
        """Hold per-image inputs plus accumulators for the indexing results.

        Args:
            experiments: Sequence of experiments, one per image.
            reflections: Sequence of strong-spot reflection tables, parallel
                to ``experiments``.
            params: PHIL parameter object (reads ``params.indexing.nproc`` and
                ``params.indexing.multiple_lattice_search.max_lattices``).
        """
        self.experiments = experiments
        self.reflections = reflections
        self.params = params
        # Input indices in the order results arrived (completion order is
        # non-deterministic under multiprocessing); used by finalize() to
        # restore input order.
        self._results_order = np.array([], dtype=np.int32)
        # Per-input-index list of summary rows, one row per indexed lattice.
        self._results = defaultdict(list)
        # Number of strong spots per input image, for the "n_indexed" column.
        self._n_strong = np.array([table.size() for table in self.reflections])
        # Running count of lattices found across all images.
        self._n_found = 0
        # Indexed reflection tables / experiment lists, in arrival order
        # (parallel to _results_order).
        self._tables_list = []
        self._expts_list = []
        # Final outputs, populated by finalize().
        self.indexed_experiments = ExperimentList()
        self.indexed_reflections = flex.reflection_table()
        self.summary_table = ""

    def process_output(self, result):
        """Accumulate one result from ``_index_one``.

        Called as the ``multi_node_parallel_map`` callback, so results may
        arrive in any order.

        Args:
            result: 3-tuple of (indexed experiments, indexed reflections,
                input index) as returned by ``_index_one``. If indexing
                failed, the first element is falsy and the result is dropped.
        """
        idx_expts, idx_refl, index = result[0], result[1], result[2]
        if idx_expts:
            self._results_order = np.append(self._results_order, [index])
            ids_map = dict(idx_refl.experiment_identifiers())
            path = idx_expts[0].imageset.get_path(index)
            # One summary row per lattice (i.e. per experiment id) found on
            # this image.
            for n_cryst, id_ in enumerate(ids_map.keys()):
                selr = idx_refl.select(idx_refl["id"] == id_)
                calx, caly, calz = selr["xyzcal.px"].parts()
                obsx, obsy, obsz = selr["xyzobs.px.value"].parts()
                delpsi = selr["delpsical.rad"]
                # Root-mean-square deviations between calculated and observed
                # positions, in pixels (x, y) and degrees (delta-psi).
                rmsd_x = flex.mean((calx - obsx) ** 2) ** 0.5
                rmsd_y = flex.mean((caly - obsy) ** 2) ** 0.5
                rmsd_z = flex.mean(((delpsi) * RAD2DEG) ** 2) ** 0.5
                n_id_ = calx.size()
                # NOTE(review): assumes self._n_strong[index] > 0 — presumably
                # guaranteed because indexing cannot succeed with no strong
                # spots; verify against _index_one.
                n_indexed = f"{n_id_}/{self._n_strong[index]} ({100*n_id_/self._n_strong[index]:2.1f}%)"
                self._results[index].append(
                    [
                        path.split("/")[-1],
                        n_indexed,
                        f"{rmsd_x:.3f}",
                        f"{rmsd_y:.3f}",
                        f" {rmsd_z:.4f}",
                    ]
                )
            self._n_found += len(ids_map.keys())
            self._tables_list.append(idx_refl)
            # Rebuild each experiment with its imageset sliced down to the
            # single image it came from.
            elist = ExperimentList()
            for jexpt in idx_expts:
                elist.append(
                    Experiment(
                        identifier=jexpt.identifier,
                        beam=jexpt.beam,
                        detector=jexpt.detector,
                        scan=jexpt.scan,
                        goniometer=jexpt.goniometer,
                        crystal=jexpt.crystal,
                        imageset=jexpt.imageset[index : index + 1],
                    )
                )
            self._expts_list.append(elist)

    def process(self, method_list):
        """Index every (experiment, reflections) pair, in parallel if configured.

        Uses ``multi_node_parallel_map`` with ``params.indexing.nproc``
        processes when nproc > 1; otherwise runs serially in-process.
        Results are fed to :meth:`process_output` either way.

        Args:
            method_list: Indexing methods to attempt, forwarded to
                ``_index_one``.
        """
        inputs = []
        for i, (expt, refl) in enumerate(zip(self.experiments, self.reflections)):
            inputs.append((expt, refl, self.params, method_list, i))
        mp_nproc = self.params.indexing.nproc
        mp_njobs = 1
        mp_method = "multiprocessing"

        def execute_task(input_):
            # Unpack one (expt, refl, params, method_list, index) tuple.
            return _index_one(*input_)

        if mp_nproc > 1:
            multi_node_parallel_map(
                func=execute_task,
                iterable=inputs,
                njobs=mp_njobs,
                nproc=mp_nproc,
                callback=self.process_output,
                cluster_method=mp_method,
                preserve_order=True,
            )
        else:
            # Serial fallback keeps behaviour identical without pool overhead.
            for input_ in inputs:
                self.process_output(execute_task(input_))

    def finalize(self):
        """Sort accumulated results into input order and build the summary.

        Populates ``self.indexed_experiments``, ``self.indexed_reflections``
        and ``self.summary_table``. Must be called after :meth:`process`.
        """
        # Results came in a non-deterministic order. So sort them out
        # to match the input order, adjusting the ids as appropriate.
        n_tot = 0
        for i in np.argsort(self._results_order):
            table = self._tables_list[i]
            expts = self._expts_list[i]
            self.indexed_experiments.extend(expts)
            # Re-key experiment ids so they are unique across the combined
            # table: clear the per-image identifier map, offset the "id"
            # column, then re-insert the identifiers at the offset keys.
            ids_map = dict(table.experiment_identifiers())
            for k in table.experiment_identifiers().keys():
                del table.experiment_identifiers()[k]
            table["id"] += n_tot
            for k, v in ids_map.items():
                table.experiment_identifiers()[k + n_tot] = v
            n_tot += len(ids_map.keys())
            self.indexed_reflections.extend(table)
        self.indexed_reflections.assert_experiment_identifiers_are_consistent(
            self.indexed_experiments
        )

        # Add a few extra useful items to the summary table.
        overall_summary_header = [
            "Image",
            "expt_id",
            "n_indexed",
            "RMSD_X",
            "RMSD_Y",
            "RMSD_dPsi",
        ]

        rows = []
        total = 0
        # Only show the per-image lattice number when multiple lattices per
        # image were searched for.
        if self.params.indexing.multiple_lattice_search.max_lattices > 1:
            show_lattices = True
            overall_summary_header.insert(1, "lattice")
        else:
            show_lattices = False
        for i, k in enumerate(sorted(self._results.keys())):
            for j, cryst in enumerate(self._results[k]):
                # Insert the running experiment id (and lattice number) into
                # each pre-built row from process_output.
                cryst.insert(1, total)
                if show_lattices:
                    cryst.insert(1, j + 1)
                rows.append(cryst)
                total += 1

        self.summary_table = tabulate(rows, overall_summary_header)
257+
258+
98259
def index(experiments, observed, params):
99260
params.refinement.parameterisation.scan_varying = False
100261
params.indexing.stills.indexer = "stills"
101262

102-
def run_with_disabled_logs(fn, fnargs):
103-
sys.stdout = open(os.devnull, "w") # block printing from rstbx
104-
log1 = logging.getLogger("dials.algorithms.refinement.reflection_manager")
105-
log2 = logging.getLogger("dials.algorithms.refinement.refiner")
106-
log3 = logging.getLogger("dials.algorithms.indexing.stills_indexer")
107-
log4 = logging.getLogger("dials.algorithms.indexing.nave_parameters")
108-
log5 = logging.getLogger(
109-
"dials.algorithms.indexing.basis_vector_search.real_space_grid_search"
110-
)
111-
log6 = logging.getLogger(
112-
"dials.algorithms.indexing.basis_vector_search.combinations"
113-
)
114-
log7 = logging.getLogger("dials.algorithms.indexing.indexer")
115-
with LoggingContext(log1, level=logging.ERROR):
116-
with LoggingContext(log2, level=logging.ERROR):
117-
with LoggingContext(log3, level=logging.ERROR):
118-
with LoggingContext(log4, level=logging.ERROR):
119-
with LoggingContext(log5, level=logging.ERROR):
120-
with LoggingContext(log6, level=logging.ERROR):
121-
with LoggingContext(log7, level=logging.ERROR):
122-
return fn(*fnargs)
123-
sys.stdout = sys.__stdout__ # restore printing
124-
125263
reflections = observed.split_by_experiment_id()
126264
# Calculate necessary quantities
127265
for refl, experiment in zip(reflections, experiments):
@@ -142,9 +280,6 @@ def run_with_disabled_logs(fn, fnargs):
142280
logger.info(f"Setting max cell to {max(max_cells):.1f} " + "\u212B")
143281
params.indexing.max_cell = max(max_cells)
144282

145-
n_strong = np.array([table.size() for table in reflections])
146-
indexed_experiments = ExperimentList()
147-
indexed_reflections = flex.reflection_table()
148283
method_list = params.method
149284
if "real_space_grid_search" in method_list:
150285
if not params.indexing.known_symmetry.unit_cell:
@@ -155,110 +290,13 @@ def run_with_disabled_logs(fn, fnargs):
155290
logger.info(f"Attempting indexing with {methods} method{pl}")
156291

157292
def index_all(experiments, reflections, params):
158-
n_found = 0
159-
overall_summary_header = [
160-
"Image",
161-
"expt_id",
162-
"n_indexed",
163-
"RMSD_X",
164-
"RMSD_Y",
165-
"RMSD_dPsi",
166-
]
167-
168-
futures = []
169-
results = defaultdict(list)
170-
tables_list = []
171-
expts_list = []
172-
results_order = np.array([], dtype=np.int32)
173-
174-
with concurrent.futures.ProcessPoolExecutor(
175-
max_workers=params.indexing.nproc
176-
) as pool:
177-
for i, (expt, refl) in enumerate(zip(experiments, reflections)):
178-
futures.append(
179-
pool.submit(_index_one, expt, refl, params, method_list, i)
180-
)
181-
for future in concurrent.futures.as_completed(futures):
182-
idx_expts, idx_refl, index = future.result()
183-
if idx_expts:
184-
results_order = np.append(results_order, [index])
185-
ids_map = dict(idx_refl.experiment_identifiers())
186-
path = expt.imageset.get_path(index)
187-
for n_cryst, id_ in enumerate(ids_map.keys()):
188-
selr = idx_refl.select(idx_refl["id"] == id_)
189-
calx, caly, calz = selr["xyzcal.px"].parts()
190-
obsx, obsy, obsz = selr["xyzobs.px.value"].parts()
191-
delpsi = selr["delpsical.rad"]
192-
rmsd_x = flex.mean((calx - obsx) ** 2) ** 0.5
193-
rmsd_y = flex.mean((caly - obsy) ** 2) ** 0.5
194-
rmsd_z = flex.mean(((delpsi) * RAD2DEG) ** 2) ** 0.5
195-
n_id_ = calx.size()
196-
n_indexed = f"{n_id_}/{n_strong[index]} ({100*n_id_/n_strong[index]:2.1f}%)"
197-
results[index].append(
198-
[
199-
path.split("/")[-1],
200-
n_indexed,
201-
f"{rmsd_x:.3f}",
202-
f"{rmsd_y:.3f}",
203-
f" {rmsd_z:.4f}",
204-
]
205-
)
206-
n_found += len(ids_map.keys())
207-
tables_list.append(idx_refl)
208-
elist = ExperimentList()
209-
for jexpt in idx_expts:
210-
elist.append(
211-
Experiment(
212-
identifier=jexpt.identifier,
213-
beam=jexpt.beam,
214-
detector=jexpt.detector,
215-
scan=jexpt.scan,
216-
goniometer=jexpt.goniometer,
217-
crystal=jexpt.crystal,
218-
imageset=jexpt.imageset[index : index + 1],
219-
)
220-
)
221-
expts_list.append(elist)
222-
223-
# Results came in a non-deterministic order. So sort them out
224-
# to match the input order, adjusting the ids as appropriate.
225-
n_tot = 0
226-
for i in np.argsort(results_order):
227-
table = tables_list[i]
228-
expts = expts_list[i]
229-
indexed_experiments.extend(expts)
230-
ids_map = dict(table.experiment_identifiers())
231-
for k in table.experiment_identifiers().keys():
232-
del table.experiment_identifiers()[k]
233-
table["id"] += n_tot
234-
for k, v in ids_map.items():
235-
table.experiment_identifiers()[k + n_tot] = v
236-
n_tot += len(ids_map.keys())
237-
indexed_reflections.extend(table)
238-
indexed_reflections.assert_experiment_identifiers_are_consistent(
239-
indexed_experiments
240-
)
241-
242-
# Add a few extra useful items to the summary table.
243-
rows = []
244-
total = 0
245-
if params.indexing.multiple_lattice_search.max_lattices > 1:
246-
show_lattices = True
247-
overall_summary_header.insert(1, "lattice")
248-
else:
249-
show_lattices = False
250-
for i, k in enumerate(sorted(results.keys())):
251-
for j, cryst in enumerate(results[k]):
252-
cryst.insert(1, total)
253-
if show_lattices:
254-
cryst.insert(1, j + 1)
255-
rows.append(cryst)
256-
total += 1
257-
293+
processor = Processor(experiments, reflections, params)
294+
processor.process(method_list)
295+
processor.finalize()
258296
return (
259-
indexed_experiments,
260-
indexed_reflections,
261-
tabulate(rows, overall_summary_header),
297+
processor.indexed_experiments,
298+
processor.indexed_reflections,
299+
processor.summary_table,
262300
)
263301

264302
indexed_experiments, indexed_reflections, summary = run_with_disabled_logs(
@@ -288,8 +326,7 @@ def index_all(experiments, reflections, params):
288326
len(indexed_experiments.beams())
289327
) > 1:
290328
combine = CombineWithReference(
291-
detector=indexed_experiments[0].detector,
292-
beam=indexed_experiments[0].beam,
329+
detector=indexed_experiments[0].detector, beam=indexed_experiments[0].beam
293330
)
294331
elist = ExperimentList()
295332
for expt in indexed_experiments:

0 commit comments

Comments
 (0)