@@ -741,10 +741,10 @@ def position_refinement(
741741 diffraction_patterns_projected = copy_to_device (
742742 self ._diffraction_patterns_projected [datacube_numbers [a0 ]], device
743743 )
744- error_shifts = np .zeros ((y_values .shape [0 ], 4 ))
744+ error_shifts = np .zeros ((y_values .shape [0 ] - 2 , 4 ))
745745 position_deltas = [(1 , 0 ), (- 1 , 0 ), (0 , 1 ), (0 , - 1 )]
746746
747- for a2 in range (y_values .shape [0 ]):
747+ for a2 in range (y_values .shape [0 ] - 2 ):
748748 object_sliced = self ._forward (
749749 datacube_number = datacube_numbers [a0 ],
750750 x_index = a2 ,
@@ -765,7 +765,7 @@ def position_refinement(
765765 object_sliced = object_sliced ,
766766 diffraction_patterns_projected = diffraction_patterns_projected ,
767767 datacube_number = datacube_numbers [a0 ],
768- x_index = a2 ,
768+ x_index = a2 + 1 ,
769769 )
770770
771771 error_shifts [a2 , a3 ] = error
@@ -793,8 +793,10 @@ def position_refinement(
793793 if max_total_displacement is not None :
794794 position_update = np .clip (
795795 position_update ,
796- - max_total_displacement - self ._position_refinements [datacube_numbers [a0 ]],
797- max_total_displacement - self ._position_refinements [datacube_numbers [a0 ]],
796+ - max_total_displacement
797+ - self ._position_refinements [datacube_numbers [a0 ]],
798+ max_total_displacement
799+ - self ._position_refinements [datacube_numbers [a0 ]],
798800 )
799801
800802 x_vox = positions_save [0 ].copy () + position_update [0 ]
@@ -1772,6 +1774,7 @@ def _calculate_update(
17721774 update = xp .zeros (object_sliced .shape )
17731775 error = 0
17741776 error = copy_to_device (error , "cpu" )
1777+ dp_patterns_counted = np .asarray ([0 ])
17751778
17761779 else :
17771780 weights = np .hstack (
@@ -1826,6 +1829,7 @@ def _calculate_update(
18261829 ).reshape ((s [1 ], dp_length ))
18271830
18281831 update = dp_patterns_counted - object_sliced
1832+ update [dp_patterns_counted .sum (1 ) == 0 ] = 0
18291833
18301834 error = (
18311835 xp .mean (update .ravel () ** 2 ) ** 0.5 / dp_patterns_counted .mean (0 ).sum ()
0 commit comments