@@ -101,3 +101,85 @@ def test_fit_grains(
101101 )
102102
103103 assert cresult
104+
105+
106+ def test_fit_grains_return_pull_spots_data (
107+ single_ge_include_path : Path ,
108+ test_config : config .root .RootConfig ,
109+ grains_reference_file_path : Path ,
110+ ) -> None :
111+ os .chdir (str (single_ge_include_path ))
112+
113+ grains_table : np .ndarray = np .loadtxt (grains_reference_file_path , ndmin = 2 )
114+
115+ result = fit_grains (
116+ test_config ,
117+ grains_table ,
118+ show_progress = False ,
119+ ids_to_refine = None ,
120+ write_spots_files = False ,
121+ return_pull_spots_data = True ,
122+ )
123+
124+ # Should return a (fit_results, spots_data) tuple
125+ assert isinstance (result , tuple )
126+ assert len (result ) == 2
127+
128+ fit_results , spots_data = result
129+
130+ # fit_results should be a list of 4-element tuples
131+ assert isinstance (fit_results , list )
132+ assert len (fit_results ) > 0
133+ for grain_result in fit_results :
134+ assert len (grain_result ) == 4
135+ grain_id , completeness , chisq , grain_params = grain_result
136+ assert isinstance (grain_id , (int , np .integer ))
137+ assert isinstance (completeness , float )
138+ assert isinstance (grain_params , np .ndarray )
139+ assert grain_params .shape == (12 ,)
140+
141+ # spots_data should be a dict keyed by grain_id
142+ assert isinstance (spots_data , dict )
143+ assert len (spots_data ) == len (fit_results )
144+
145+ for grain_id , (complvec , results ) in spots_data .items ():
146+ # complvec is a list of booleans
147+ assert isinstance (complvec , list )
148+
149+ # results is a dict keyed by detector name
150+ assert isinstance (results , dict )
151+ assert len (results ) > 0
152+
153+ for det_key , det_results in results .items ():
154+ assert isinstance (det_key , str )
155+ assert isinstance (det_results , list )
156+ assert len (det_results ) > 0
157+
158+ for spot in det_results :
159+ # Each spot should have 9 elements (including pred_xy)
160+ assert len (spot ) == 9 , (
161+ f'Expected 9 elements per spot, got { len (spot )} '
162+ )
163+
164+ peak_id = spot [0 ]
165+ hkl = spot [2 ]
166+ pred_angs = spot [5 ]
167+ meas_angs = spot [6 ]
168+ meas_xy = spot [7 ]
169+ pred_xy = spot [8 ]
170+
171+ assert isinstance (peak_id , (int , np .integer ))
172+ assert isinstance (hkl , np .ndarray )
173+ assert hkl .shape == (3 ,)
174+ assert isinstance (pred_angs , np .ndarray )
175+ assert pred_angs .shape == (3 ,)
176+
177+ # meas_angs/meas_xy may be None for invalid spots
178+ if peak_id >= 0 :
179+ assert isinstance (meas_angs , np .ndarray )
180+ assert meas_angs .shape == (3 ,)
181+ assert isinstance (meas_xy , np .ndarray )
182+ assert meas_xy .shape == (2 ,)
183+
184+ assert isinstance (pred_xy , np .ndarray )
185+ assert pred_xy .shape == (2 ,)
0 commit comments