3030from xfel .clustering .cluster_groups import unit_cell_info
3131from cctbx import crystal
3232from dials .command_line .combine_experiments import CombineWithReference
33+ from dials .util .mp import multi_node_parallel_map
3334
3435try :
3536 from typing import List
@@ -95,33 +96,170 @@ def _index_one(experiment, refl, params, method_list, expt_no):
9596 return idxr .refined_experiments , idxr .refined_reflections , expt_no
9697
9798
99+ def run_with_disabled_logs (fn , fnargs ):
100+ sys .stdout = open (os .devnull , "w" ) # block printing from rstbx
101+ log1 = logging .getLogger ("dials.algorithms.refinement.reflection_processor" )
102+ log2 = logging .getLogger ("dials.algorithms.refinement.refiner" )
103+ log8 = logging .getLogger ("dials.algorithms.refinement.reflection_manager" )
104+ log3 = logging .getLogger ("dials.algorithms.indexing.stills_indexer" )
105+ log4 = logging .getLogger ("dials.algorithms.indexing.nave_parameters" )
106+ log5 = logging .getLogger (
107+ "dials.algorithms.indexing.basis_vector_search.real_space_grid_search"
108+ )
109+ log6 = logging .getLogger (
110+ "dials.algorithms.indexing.basis_vector_search.combinations"
111+ )
112+ log7 = logging .getLogger ("dials.algorithms.indexing.indexer" )
113+ with LoggingContext (log1 , level = logging .ERROR ):
114+ with LoggingContext (log2 , level = logging .ERROR ):
115+ with LoggingContext (log3 , level = logging .ERROR ):
116+ with LoggingContext (log4 , level = logging .ERROR ):
117+ with LoggingContext (log5 , level = logging .ERROR ):
118+ with LoggingContext (log6 , level = logging .ERROR ):
119+ with LoggingContext (log7 , level = logging .ERROR ):
120+ with LoggingContext (log8 , level = logging .ERROR ):
121+ result = fn (* fnargs )
122+ sys .stdout = sys .__stdout__ # restore printing
123+ return result
124+
125+
126+ class Processor (object ):
127+ # Wrap some functions into a class to allow multiprocessing with
128+ # multi_node_parallel_map
129+
130+ def __init__ (self , experiments , reflections , params ):
131+ self .experiments = experiments
132+ self .reflections = reflections
133+ self .params = params
134+ self ._results_order = np .array ([], dtype = np .int32 )
135+ self ._results = defaultdict (list )
136+ self ._n_strong = np .array ([table .size () for table in self .reflections ])
137+ self ._n_found = 0
138+ self ._tables_list = []
139+ self ._expts_list = []
140+ self .indexed_experiments = ExperimentList ()
141+ self .indexed_reflections = flex .reflection_table ()
142+ self .summary_table = ""
143+
144+ def process_output (self , result ):
145+ idx_expts , idx_refl , index = result [0 ], result [1 ], result [2 ]
146+ if idx_expts :
147+ self ._results_order = np .append (self ._results_order , [index ])
148+ ids_map = dict (idx_refl .experiment_identifiers ())
149+ path = idx_expts [0 ].imageset .get_path (index )
150+ for n_cryst , id_ in enumerate (ids_map .keys ()):
151+ selr = idx_refl .select (idx_refl ["id" ] == id_ )
152+ calx , caly , calz = selr ["xyzcal.px" ].parts ()
153+ obsx , obsy , obsz = selr ["xyzobs.px.value" ].parts ()
154+ delpsi = selr ["delpsical.rad" ]
155+ rmsd_x = flex .mean ((calx - obsx ) ** 2 ) ** 0.5
156+ rmsd_y = flex .mean ((caly - obsy ) ** 2 ) ** 0.5
157+ rmsd_z = flex .mean (((delpsi ) * RAD2DEG ) ** 2 ) ** 0.5
158+ n_id_ = calx .size ()
159+ n_indexed = f"{ n_id_ } /{ self ._n_strong [index ]} ({ 100 * n_id_ / self ._n_strong [index ]:2.1f} %)"
160+ self ._results [index ].append (
161+ [
162+ path .split ("/" )[- 1 ],
163+ n_indexed ,
164+ f"{ rmsd_x :.3f} " ,
165+ f"{ rmsd_y :.3f} " ,
166+ f" { rmsd_z :.4f} " ,
167+ ]
168+ )
169+ self ._n_found += len (ids_map .keys ())
170+ self ._tables_list .append (idx_refl )
171+ elist = ExperimentList ()
172+ for jexpt in idx_expts :
173+ elist .append (
174+ Experiment (
175+ identifier = jexpt .identifier ,
176+ beam = jexpt .beam ,
177+ detector = jexpt .detector ,
178+ scan = jexpt .scan ,
179+ goniometer = jexpt .goniometer ,
180+ crystal = jexpt .crystal ,
181+ imageset = jexpt .imageset [index : index + 1 ],
182+ )
183+ )
184+ self ._expts_list .append (elist )
185+
186+ def process (self , method_list ):
187+ inputs = []
188+ for i , (expt , refl ) in enumerate (zip (self .experiments , self .reflections )):
189+ inputs .append ((expt , refl , self .params , method_list , i ))
190+ mp_nproc = self .params .indexing .nproc
191+ mp_njobs = 1
192+ mp_method = "multiprocessing"
193+
194+ def execute_task (input_ ):
195+ return _index_one (* input_ )
196+
197+ if mp_nproc > 1 :
198+ multi_node_parallel_map (
199+ func = execute_task ,
200+ iterable = inputs ,
201+ njobs = mp_njobs ,
202+ nproc = mp_nproc ,
203+ callback = self .process_output ,
204+ cluster_method = mp_method ,
205+ preserve_order = True ,
206+ )
207+ else :
208+ for input_ in inputs :
209+ self .process_output (execute_task (input_ ))
210+
211+ def finalize (self ):
212+ # Results came in a non-deterministic order. So sort them out
213+ # to match the input order, adjusting the ids as appropriate.
214+ n_tot = 0
215+ for i in np .argsort (self ._results_order ):
216+ table = self ._tables_list [i ]
217+ expts = self ._expts_list [i ]
218+ self .indexed_experiments .extend (expts )
219+ ids_map = dict (table .experiment_identifiers ())
220+ for k in table .experiment_identifiers ().keys ():
221+ del table .experiment_identifiers ()[k ]
222+ table ["id" ] += n_tot
223+ for k , v in ids_map .items ():
224+ table .experiment_identifiers ()[k + n_tot ] = v
225+ n_tot += len (ids_map .keys ())
226+ self .indexed_reflections .extend (table )
227+ self .indexed_reflections .assert_experiment_identifiers_are_consistent (
228+ self .indexed_experiments
229+ )
230+
231+ # Add a few extra useful items to the summary table.
232+ overall_summary_header = [
233+ "Image" ,
234+ "expt_id" ,
235+ "n_indexed" ,
236+ "RMSD_X" ,
237+ "RMSD_Y" ,
238+ "RMSD_dPsi" ,
239+ ]
240+
241+ rows = []
242+ total = 0
243+ if self .params .indexing .multiple_lattice_search .max_lattices > 1 :
244+ show_lattices = True
245+ overall_summary_header .insert (1 , "lattice" )
246+ else :
247+ show_lattices = False
248+ for i , k in enumerate (sorted (self ._results .keys ())):
249+ for j , cryst in enumerate (self ._results [k ]):
250+ cryst .insert (1 , total )
251+ if show_lattices :
252+ cryst .insert (1 , j + 1 )
253+ rows .append (cryst )
254+ total += 1
255+
256+ self .summary_table = tabulate (rows , overall_summary_header )
257+
258+
98259def index (experiments , observed , params ):
99260 params .refinement .parameterisation .scan_varying = False
100261 params .indexing .stills .indexer = "stills"
101262
102- def run_with_disabled_logs (fn , fnargs ):
103- sys .stdout = open (os .devnull , "w" ) # block printing from rstbx
104- log1 = logging .getLogger ("dials.algorithms.refinement.reflection_manager" )
105- log2 = logging .getLogger ("dials.algorithms.refinement.refiner" )
106- log3 = logging .getLogger ("dials.algorithms.indexing.stills_indexer" )
107- log4 = logging .getLogger ("dials.algorithms.indexing.nave_parameters" )
108- log5 = logging .getLogger (
109- "dials.algorithms.indexing.basis_vector_search.real_space_grid_search"
110- )
111- log6 = logging .getLogger (
112- "dials.algorithms.indexing.basis_vector_search.combinations"
113- )
114- log7 = logging .getLogger ("dials.algorithms.indexing.indexer" )
115- with LoggingContext (log1 , level = logging .ERROR ):
116- with LoggingContext (log2 , level = logging .ERROR ):
117- with LoggingContext (log3 , level = logging .ERROR ):
118- with LoggingContext (log4 , level = logging .ERROR ):
119- with LoggingContext (log5 , level = logging .ERROR ):
120- with LoggingContext (log6 , level = logging .ERROR ):
121- with LoggingContext (log7 , level = logging .ERROR ):
122- return fn (* fnargs )
123- sys .stdout = sys .__stdout__ # restore printing
124-
125263 reflections = observed .split_by_experiment_id ()
126264 # Calculate necessary quantities
127265 for refl , experiment in zip (reflections , experiments ):
@@ -142,9 +280,6 @@ def run_with_disabled_logs(fn, fnargs):
142280 logger .info (f"Setting max cell to { max (max_cells ):.1f} " + "\u212B " )
143281 params .indexing .max_cell = max (max_cells )
144282
145- n_strong = np .array ([table .size () for table in reflections ])
146- indexed_experiments = ExperimentList ()
147- indexed_reflections = flex .reflection_table ()
148283 method_list = params .method
149284 if "real_space_grid_search" in method_list :
150285 if not params .indexing .known_symmetry .unit_cell :
@@ -155,110 +290,13 @@ def run_with_disabled_logs(fn, fnargs):
155290 logger .info (f"Attempting indexing with { methods } method{ pl } " )
156291
157292 def index_all (experiments , reflections , params ):
158- n_found = 0
159- overall_summary_header = [
160- "Image" ,
161- "expt_id" ,
162- "n_indexed" ,
163- "RMSD_X" ,
164- "RMSD_Y" ,
165- "RMSD_dPsi" ,
166- ]
167-
168- futures = []
169- results = defaultdict (list )
170- tables_list = []
171- expts_list = []
172- results_order = np .array ([], dtype = np .int32 )
173-
174- with concurrent .futures .ProcessPoolExecutor (
175- max_workers = params .indexing .nproc
176- ) as pool :
177- for i , (expt , refl ) in enumerate (zip (experiments , reflections )):
178- futures .append (
179- pool .submit (_index_one , expt , refl , params , method_list , i )
180- )
181- for future in concurrent .futures .as_completed (futures ):
182- idx_expts , idx_refl , index = future .result ()
183- if idx_expts :
184- results_order = np .append (results_order , [index ])
185- ids_map = dict (idx_refl .experiment_identifiers ())
186- path = expt .imageset .get_path (index )
187- for n_cryst , id_ in enumerate (ids_map .keys ()):
188- selr = idx_refl .select (idx_refl ["id" ] == id_ )
189- calx , caly , calz = selr ["xyzcal.px" ].parts ()
190- obsx , obsy , obsz = selr ["xyzobs.px.value" ].parts ()
191- delpsi = selr ["delpsical.rad" ]
192- rmsd_x = flex .mean ((calx - obsx ) ** 2 ) ** 0.5
193- rmsd_y = flex .mean ((caly - obsy ) ** 2 ) ** 0.5
194- rmsd_z = flex .mean (((delpsi ) * RAD2DEG ) ** 2 ) ** 0.5
195- n_id_ = calx .size ()
196- n_indexed = f"{ n_id_ } /{ n_strong [index ]} ({ 100 * n_id_ / n_strong [index ]:2.1f} %)"
197- results [index ].append (
198- [
199- path .split ("/" )[- 1 ],
200- n_indexed ,
201- f"{ rmsd_x :.3f} " ,
202- f"{ rmsd_y :.3f} " ,
203- f" { rmsd_z :.4f} " ,
204- ]
205- )
206- n_found += len (ids_map .keys ())
207- tables_list .append (idx_refl )
208- elist = ExperimentList ()
209- for jexpt in idx_expts :
210- elist .append (
211- Experiment (
212- identifier = jexpt .identifier ,
213- beam = jexpt .beam ,
214- detector = jexpt .detector ,
215- scan = jexpt .scan ,
216- goniometer = jexpt .goniometer ,
217- crystal = jexpt .crystal ,
218- imageset = jexpt .imageset [index : index + 1 ],
219- )
220- )
221- expts_list .append (elist )
222-
223- # Results came in a non-deterministic order. So sort them out
224- # to match the input order, adjusting the ids as appropriate.
225- n_tot = 0
226- for i in np .argsort (results_order ):
227- table = tables_list [i ]
228- expts = expts_list [i ]
229- indexed_experiments .extend (expts )
230- ids_map = dict (table .experiment_identifiers ())
231- for k in table .experiment_identifiers ().keys ():
232- del table .experiment_identifiers ()[k ]
233- table ["id" ] += n_tot
234- for k , v in ids_map .items ():
235- table .experiment_identifiers ()[k + n_tot ] = v
236- n_tot += len (ids_map .keys ())
237- indexed_reflections .extend (table )
238- indexed_reflections .assert_experiment_identifiers_are_consistent (
239- indexed_experiments
240- )
241-
242- # Add a few extra useful items to the summary table.
243- rows = []
244- total = 0
245- if params .indexing .multiple_lattice_search .max_lattices > 1 :
246- show_lattices = True
247- overall_summary_header .insert (1 , "lattice" )
248- else :
249- show_lattices = False
250- for i , k in enumerate (sorted (results .keys ())):
251- for j , cryst in enumerate (results [k ]):
252- cryst .insert (1 , total )
253- if show_lattices :
254- cryst .insert (1 , j + 1 )
255- rows .append (cryst )
256- total += 1
257-
293+ processor = Processor (experiments , reflections , params )
294+ processor .process (method_list )
295+ processor .finalize ()
258296 return (
259- indexed_experiments ,
260- indexed_reflections ,
261- tabulate ( rows , overall_summary_header ) ,
297+ processor . indexed_experiments ,
298+ processor . indexed_reflections ,
299+ processor . summary_table ,
262300 )
263301
264302 indexed_experiments , indexed_reflections , summary = run_with_disabled_logs (
@@ -288,8 +326,7 @@ def index_all(experiments, reflections, params):
288326 len (indexed_experiments .beams ())
289327 ) > 1 :
290328 combine = CombineWithReference (
291- detector = indexed_experiments [0 ].detector ,
292- beam = indexed_experiments [0 ].beam ,
329+ detector = indexed_experiments [0 ].detector , beam = indexed_experiments [0 ].beam
293330 )
294331 elist = ExperimentList ()
295332 for expt in indexed_experiments :
0 commit comments