Skip to content

Commit d78b3f5

Browse files
committed
more position updates
1 parent 75accfa commit d78b3f5

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

py4DSTEM/tomography/tomography.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -741,10 +741,10 @@ def position_refinement(
741741
diffraction_patterns_projected = copy_to_device(
742742
self._diffraction_patterns_projected[datacube_numbers[a0]], device
743743
)
744-
error_shifts = np.zeros((y_values.shape[0], 4))
744+
error_shifts = np.zeros((y_values.shape[0] - 2, 4))
745745
position_deltas = [(1, 0), (-1, 0), (0, 1), (0, -1)]
746746

747-
for a2 in range(y_values.shape[0]):
747+
for a2 in range(y_values.shape[0] - 2):
748748
object_sliced = self._forward(
749749
datacube_number=datacube_numbers[a0],
750750
x_index=a2,
@@ -765,7 +765,7 @@ def position_refinement(
765765
object_sliced=object_sliced,
766766
diffraction_patterns_projected=diffraction_patterns_projected,
767767
datacube_number=datacube_numbers[a0],
768-
x_index=a2,
768+
x_index=a2 + 1,
769769
)
770770

771771
error_shifts[a2, a3] = error
@@ -793,8 +793,10 @@ def position_refinement(
793793
if max_total_displacement is not None:
794794
position_update = np.clip(
795795
position_update,
796-
-max_total_displacement - self._position_refinements[datacube_numbers[a0]],
797-
max_total_displacement - self._position_refinements[datacube_numbers[a0]],
796+
-max_total_displacement
797+
- self._position_refinements[datacube_numbers[a0]],
798+
max_total_displacement
799+
- self._position_refinements[datacube_numbers[a0]],
798800
)
799801

800802
x_vox = positions_save[0].copy() + position_update[0]
@@ -1772,6 +1774,7 @@ def _calculate_update(
17721774
update = xp.zeros(object_sliced.shape)
17731775
error = 0
17741776
error = copy_to_device(error, "cpu")
1777+
dp_patterns_counted = np.asarray([0])
17751778

17761779
else:
17771780
weights = np.hstack(
@@ -1826,6 +1829,7 @@ def _calculate_update(
18261829
).reshape((s[1], dp_length))
18271830

18281831
update = dp_patterns_counted - object_sliced
1832+
update[dp_patterns_counted.sum(1) == 0] = 0
18291833

18301834
error = (
18311835
xp.mean(update.ravel() ** 2) ** 0.5 / dp_patterns_counted.mean(0).sum()

0 commit comments

Comments
 (0)