diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4ecfc2eb..0de6054d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,14 +18,14 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.6 + rev: v0.9.10 hooks: - id: ruff args: [ --config=pyproject.toml ] - id: ruff-format args: [ --config=pyproject.toml ] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.13.0 + rev: v1.15.0 hooks: - id: mypy additional_dependencies: @@ -39,8 +39,9 @@ repos: additional_dependencies: [setuptools-scm, wheel] - repo: https://github.com/codespell-project/codespell # Configuration for codespell is in pyproject.toml - rev: v2.3.0 + rev: v2.4.1 hooks: - id: codespell additional_dependencies: - tomli + args: ["--ignore-words-list=ptd"] diff --git a/derotation/analysis/bayesian_optimization.py b/derotation/analysis/bayesian_optimization.py new file mode 100644 index 00000000..01008354 --- /dev/null +++ b/derotation/analysis/bayesian_optimization.py @@ -0,0 +1,115 @@ +import logging +from pathlib import Path +from typing import Tuple + +import matplotlib.pyplot as plt +import numpy as np +from bayes_opt import BayesianOptimization + +from derotation.analysis.mean_images import calculate_mean_images +from derotation.analysis.metrics import ptd_of_most_detected_blob +from derotation.derotate_by_line import derotate_an_image_array_line_by_line + + +class BO_for_derotation: + def __init__( + self, + movie: np.ndarray, + rot_deg_line: np.ndarray, + rot_deg_frame: np.ndarray, + blank_pixels_value: float, + center: Tuple[int, int], + delta: int, + init_points: int = 2, + n_iter: int = 10, + debug_plots_folder: Path = Path("./debug_plots"), + ): + """Initializes the BO_for_derotation class. + Bayesian optimization to find the best center of rotation. + + Parameters + ---------- + movie : np.ndarray + The calcium imaging movie to derotate. + rot_deg_line : np.ndarray + The rotation angles of the rotation plane. + rot_deg_frame : np.ndarray + The rotation angles of the frames. + blank_pixels_value : float + The value of the blank pixels. + center : Tuple[int, int] + The initial center of rotation. + delta : int + The variation allowed in the boundaries of the center of rotation. + init_points : int, optional + The number of initial points to evaluate, by default 2 + n_iter : int, optional + The number of iterations to run the optimization, by default 10 + debug_plots_folder : Path, optional + The folder to save the debugging plots, by default + Path("./debug_plots") + """ + self.movie = movie + self.rot_deg_line = rot_deg_line + self.rot_deg_frame = rot_deg_frame + self.blank_pixels_value = blank_pixels_value + x, y = center + self.pbounds = { + "x": (x - delta, x + delta), + "y": (y - delta, y + delta), + } + self.init_points = init_points + self.n_iter = n_iter + self.debug_plots_folder = debug_plots_folder + + self.subfolder = self.debug_plots_folder / "bo" + + def optimize(self): + def derotate_and_get_metric( + x: float, + y: float, + ): + derotated_chunk = derotate_an_image_array_line_by_line( + image_stack=self.movie, + rot_deg_line=self.rot_deg_line, + blank_pixels_value=self.blank_pixels_value, + center=(int(x), int(y)), + ) + + mean_images = calculate_mean_images( + derotated_chunk, self.rot_deg_frame, round_decimals=0 + ) + + plt.imshow(mean_images[0], cmap="gray") + plt.savefig(self.subfolder / f"mean_image_0_{x:.2f}_{y:.2f}.png") + plt.close() + + ptd = ptd_of_most_detected_blob( + mean_images, + debug_plots_folder=self.subfolder, + image_names=[ + f"blobs_{x:.2f}_{y:.2f}.png", + f"blob_centers_{x:.2f}_{y:.2f}.png", + ], + ) + + # we are maximizing the metric, so + # we need to return the negative of the metric + return -ptd + + optimizer = BayesianOptimization( + f=derotate_and_get_metric, + pbounds=self.pbounds, + verbose=2, + random_state=1, + ) + + optimizer.maximize( + init_points=self.init_points, + n_iter=self.n_iter, + ) + + for i, res in enumerate(optimizer.res): + logging.info(f"Iteration {i}: {res}") + + return optimizer.max diff --git a/derotation/analysis/blob_detection.py b/derotation/analysis/blob_detection.py index b4b538ae..5fa683db 100644 --- a/derotation/analysis/blob_detection.py +++ b/derotation/analysis/blob_detection.py @@ -96,7 +96,7 @@ def plot_blob_detection(self, blobs: list, image_stack: np.ndarray): fig, ax = plt.subplots(4, 3, figsize=(10, 10)) for i, a in tqdm(enumerate(ax.flatten())): a.imshow(image_stack[i]) - a.set_title(f"{i*5} degrees") + a.set_title(f"{i * 5} degrees") a.axis("off") for j, blob in enumerate(blobs[i][:4]): diff --git a/derotation/analysis/fit_ellipse.py b/derotation/analysis/fit_ellipse.py index 3c0cf93a..457fde49 100644 --- a/derotation/analysis/fit_ellipse.py +++ b/derotation/analysis/fit_ellipse.py @@ -10,6 +10,7 @@ def fit_ellipse_to_points( centers: np.ndarray, + pixels_in_row: int = 256, ) -> Tuple[int, int, int, int, int]: """Fit an ellipse to the points using least squares optimization. @@ -26,14 +27,19 @@ def fit_ellipse_to_points( """ # Convert centers to numpy array centers = np.array(centers) - x = centers[:, 0] - y = centers[:, 1] + valid_points = centers[ + ~np.isnan(centers).any(axis=1) + ] # Remove rows with NaN + if len(valid_points) < 5: + raise ValueError("Not enough valid points to fit an ellipse.") - # Find the extreme points for the initial ellipse estimate - topmost = centers[np.argmin(y)] - rightmost = centers[np.argmax(x)] - bottommost = centers[np.argmax(y)] - leftmost = centers[np.argmin(x)] + x, y = valid_points[:, 0], valid_points[:, 1] + + # Find extreme points for the initial ellipse estimate + topmost = valid_points[np.argmin(y)] + rightmost = valid_points[np.argmax(x)] + bottommost = valid_points[np.argmax(y)] + leftmost = valid_points[np.argmin(x)] # Initial parameters: (center_x, center_y, semi_major_axis, # semi_minor_axis, rotation_angle) @@ -42,8 +48,12 @@ def fit_ellipse_to_points( ) semi_major_axis = np.linalg.norm(rightmost - leftmost) / 2 semi_minor_axis = np.linalg.norm(topmost - bottommost) / 2 - rotation_angle = 0 # Start with no rotation + # Ensure axes are not zero + if semi_major_axis < 1e-3 or semi_minor_axis < 1e-3: + raise ValueError("Points are degenerate; cannot fit an ellipse.") + + rotation_angle = 0 # Start with no rotation initial_params = [ initial_center[0], initial_center[1], @@ -53,6 +63,7 @@ def fit_ellipse_to_points( ] logging.info("Fitting ellipse to points...") + logging.info(f"Initial parameters: {initial_params}") # Objective function to minimize: sum of squared distances to ellipse def ellipse_residuals(params, x, y): @@ -72,15 +83,26 @@ def ellipse_residuals(params, x, y): ellipse_residuals, initial_params, args=(x, y), - loss="huber", # minimize the influence of outliers + loss="huber", # Minimize the influence of outliers + bounds=( + # center_x, center_y, a, b, theta + [0, 0, 1e-3, 1e-3, -np.pi], + [ + pixels_in_row, + pixels_in_row, + pixels_in_row, + pixels_in_row, + np.pi, + ], + ), ) + if not result.success: + raise RuntimeError("Ellipse fitting did not converge.") + # Extract optimized parameters center_x, center_y, a, b, theta = result.x - # sometimes the fitted theta is a multiple of of 2pi - theta = theta % (2 * np.pi) - return center_x, center_y, a, b, theta diff --git a/derotation/analysis/full_derotation_pipeline.py b/derotation/analysis/full_derotation_pipeline.py index 6bbb03c2..067f308f 100644 --- a/derotation/analysis/full_derotation_pipeline.py +++ b/derotation/analysis/full_derotation_pipeline.py @@ -1,8 +1,9 @@ import copy +import itertools import logging import sys from pathlib import Path -from typing import Tuple +from typing import Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -10,10 +11,14 @@ import tifffile as tiff import yaml from fancylog import fancylog -from scipy.signal import find_peaks +from scipy.signal import butter, find_peaks, sosfilt +from sklearn.cluster import KMeans from sklearn.mixture import GaussianMixture -from tifffile import imsave +from tifffile import imwrite +from derotation.analysis.bayesian_optimization import BO_for_derotation +from derotation.analysis.mean_images import calculate_mean_images +from derotation.analysis.metrics import ptd_of_most_detected_blob from derotation.derotate_by_line import derotate_an_image_array_line_by_line from derotation.load_data.custom_data_loaders import ( get_analog_signals, @@ -27,7 +32,7 @@ class FullPipeline: """ ### ----------------- Main pipeline ----------------- ### - def __init__(self, config_name): + def __init__(self, _config: Union[dict, str]): """DerotationPipeline is a class that derotates an image stack acquired with a rotating sample under a microscope. In the constructor, it loads the config file, starts the logging @@ -38,10 +43,15 @@ def __init__(self, config_name): Parameters ---------- - config_name : str - Name of the config file without extension. + _config : Union[dict, str] + Name of the config file without extension that will be retrieved + in the derotation/config folder, or the config dictionary. """ - self.config = self.get_config(config_name) + if isinstance(_config, dict): + self.config = _config + else: + self.config = self.get_config(_config) + self.start_logging() self.load_data() @@ -56,15 +66,29 @@ def __call__(self): - saving the masked image stack """ self.process_analog_signals() + + self.offset = self.find_image_offset(self.image_stack[0]) + self.set_optimal_center() + rotated_images = self.derotate_frames_line_by_line() - masked = self.add_circle_mask(rotated_images, self.mask_diameter) - self.save(masked) + self.masked_image_volume = self.add_circle_mask( + rotated_images, self.mask_diameter + ) + self.mean_images = calculate_mean_images( + self.masked_image_volume, self.rot_deg_frame, round_decimals=0 + ) + self.metric = ptd_of_most_detected_blob( + self.mean_images, + plot=self.debugging_plots, + debug_plots_folder=self.debug_plots_folder, + ) + + self.save(self.masked_image_volume) self.save_csv_with_derotation_data() def get_config(self, config_name: str) -> dict: """Loads config file from derotation/config folder. Please edit it to change the parameters of the analysis. - Parameters ---------- config_name : str @@ -76,7 +100,10 @@ def get_config(self, config_name: str) -> dict: dict Config dictionary. """ - path_config = "derotation/config/" + config_name + ".yml" + + path_config = ( + Path(__file__).parent.parent / f"config/{config_name}.yml" + ) with open(Path(path_config), "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) @@ -95,6 +122,8 @@ def start_logging(self): filename="derotation", verbose=False, ) + # suppress debug messages from matplotlib + logging.getLogger("matplotlib.font_manager").setLevel(logging.ERROR) def load_data(self): """Loads data from the paths specified in the config file. @@ -119,7 +148,7 @@ def load_data(self): data from your setup. """ logging.info("Loading data...") - + logging.info(f"Loading {self.config['paths_read']['path_to_tif']}") self.image_stack = tiff.imread( self.config["paths_read"]["path_to_tif"] ) @@ -132,6 +161,15 @@ def load_data(self): self.direction, self.speed = read_randomized_stim_table( self.config["paths_read"]["path_to_randperm"] ) + logging.info(f"Number of rotations: {len(self.direction)}") + + rotation_direction = pd.DataFrame( + {"direction": self.direction, "speed": self.speed} + ).pivot_table( + index="direction", columns="speed", aggfunc="size", fill_value=0 + ) + # print pivot table + logging.info(f"Rotation direction: \n{rotation_direction}") self.number_of_rotations = len(self.direction) @@ -157,7 +195,13 @@ def load_data(self): self.config["paths_read"]["path_to_tif"] ).stem.split(".")[0] self.filename = self.config["paths_write"]["saving_name"] - + Path(self.config["paths_write"]["derotated_tiff_folder"]).mkdir( + parents=True, exist_ok=True + ) + self.file_saving_path_with_name = ( + Path(self.config["paths_write"]["derotated_tiff_folder"]) + / self.filename + ) self.std_coef = self.config["analog_signals_processing"][ "squared_pulse_k" ] @@ -167,21 +211,44 @@ def load_data(self): self.debugging_plots = self.config["debugging_plots"] + self.frame_rate = self.config["frame_rate"] + if self.debugging_plots: self.debug_plots_folder = Path( self.config["paths_write"]["debug_plots_folder"] ) - self.debug_plots_folder.mkdir(parents=True, exist_ok=True) + Path(self.debug_plots_folder).mkdir(parents=True, exist_ok=True) + + # unlink previous debug plots + logging.info("Deleting previous debug plots...") + for item in self.debug_plots_folder.iterdir(): + if item.is_dir(): + for file in item.iterdir(): + if file.suffix == ".png": + file.unlink() + else: + if item.suffix == ".png": + item.unlink() logging.info(f"Dataset {self.filename_raw} loaded") logging.info(f"Filename: {self.filename}") # by default the center of rotation is the center of the image - self.center_of_rotation = ( - self.num_lines_per_frame // 2, - self.num_lines_per_frame // 2, - ) + if not self.config["biased_center"]: + self.center_of_rotation = ( + self.num_lines_per_frame // 2, + self.num_lines_per_frame // 2, + ) + else: + self.center_of_rotation = tuple(self.config["biased_center"]) + self.hooks = {} + self.rotation_plane_angle = 0 + self.rotation_plane_orientation = 0 + + self.delta = self.config["delta_center"] + self.init_points = self.config["init_points"] + self.n_iter = self.config["n_iter"] ### ----------------- Analog signals processing pipeline ------------- ### def process_analog_signals(self): @@ -228,7 +295,8 @@ def process_analog_signals(self): self.line_start, self.line_end, ) = self.get_start_end_times_with_threshold( - self.line_clock, self.std_coef + self.line_clock, + self.std_coef, ) ( self.frame_start, @@ -244,7 +312,8 @@ def process_analog_signals(self): if self.debugging_plots: self.plot_rotation_on_and_ticks() - self.plot_rotation_angles() + self.plot_rotation_angles_and_velocity() + self.plot_rotation_speeds() logging.info("✨ Analog signals processed ✨") @@ -441,14 +510,41 @@ def check_number_of_rotations(self): "Start and end of rotations have different lengths" ) if self.rot_blocks_idx["start"].shape[0] != self.number_of_rotations: - logging.info( - f"Number of rotations is {self.number_of_rotations}." - + f"Adjusting to {self.rot_blocks_idx['start'].shape[0]}" + logging.warning( + f"Number of rotations is not {self.number_of_rotations}." + + f"Found {self.rot_blocks_idx['start'].shape[0]} rotations." + + "Adjusting starting and ending times..." ) - self.number_of_rotations = self.rot_blocks_idx["start"].shape[0] + self.find_missing_rotation_on_periods() logging.info("Number of rotations is as expected") + def find_missing_rotation_on_periods(self): + """ + Find the missing rotation on periods by looking at the rotation ticks + and the rotation on signal. This is useful when the number of rotations + is not as expected. + Uses k-means to cluster the ticks and pick the first and last for each + cluster. These are the starting and ending times. + """ + + kmeans = KMeans( + n_clusters=self.number_of_rotations, random_state=0 + ).fit(self.rotation_ticks_peaks.reshape(-1, 1)) + + new_start = np.zeros(self.number_of_rotations, dtype=int) + new_end = np.zeros(self.number_of_rotations, dtype=int) + for i in range(self.number_of_rotations): + new_start[i] = self.rotation_ticks_peaks[kmeans.labels_ == i].min() + new_end[i] = self.rotation_ticks_peaks[kmeans.labels_ == i].max() + + # cluster number is not the same as the number of rotations + self.rot_blocks_idx["start"] = sorted(new_start) + self.rot_blocks_idx["end"] = sorted(new_end) + + # update the rotation on signal + self.rotation_on = self.create_signed_rotation_array() + def is_number_of_ticks_correct(self) -> bool: """Compares the total number of ticks with the expected number of ticks, which is calculated from the number of rotations and the @@ -610,7 +706,11 @@ def check_rotation_number_after_interpolation( "Start and end of rotations have different lengths" ) if start.shape[0] != self.number_of_rotations: - raise ValueError( + # plot the rotation on signal and the interpolated angles + # this is useful to debug the interpolation + self.plot_rotation_on_and_ticks() + + logging.warning( "Number of rotations is not as expected after interpolation, " + f"{start.shape[0]} instead of {self.number_of_rotations}" ) @@ -692,6 +792,29 @@ def clock_to_latest_rotation_start(self, clock_time: int) -> int: """ return np.where(self.rot_blocks_idx["start"] < clock_time)[0][-1] + def calculate_velocity(self): + """Calculates the velocity of the rotation by line.""" + # Compute the correct sampling rate + self.sampling_rate = ( + self.frame_rate * self.num_lines_per_frame + ) # 1725.44 Hz + + # Unwrap angles and compute velocity + warr = np.rad2deg(np.unwrap(np.deg2rad(self.rot_deg_line))) + velocity = np.diff(warr) * self.sampling_rate + + # Butterworth low-pass filter + order = 3 + nyq = 0.5 * self.sampling_rate # Nyquist frequency + cutoff = 10 / nyq # Normalized cutoff frequency + + sos = butter( + order, cutoff, btype="low", output="sos" + ) # Use 'sos' for stability + filtered = sosfilt(sos, velocity) # Apply filter correctly + + return filtered + def plot_rotation_on_and_ticks(self): """Plots the rotation ticks and the rotation on signal. This plot will be saved in the debug_plots folder. @@ -732,25 +855,29 @@ def plot_rotation_on_and_ticks(self): plt.savefig( self.debug_plots_folder / "rotation_ticks_and_rotation_on.png" ) + plt.close() - def plot_rotation_angles(self): + def plot_rotation_angles_and_velocity(self): """Plots example rotation angles by line and frame for each speed. - This plot will be saved in the debug_plots folder. - Please inspect it to check that the rotation angles are correctly - calculated. + The velocity is also plotted on top of the rotation angles. + + This plot will be saved in the debug_plots folder. Please inspect it + to check that the rotation angles are correctly calculated. """ logging.info("Plotting rotation angles...") - fig, axs = plt.subplots(2, 2, figsize=(10, 10)) + fig, axs = plt.subplots(2, 2, figsize=(15, 10)) speeds = set(self.speed) - last_idx_for_each_speed = [ - np.where(self.speed == s)[0][-1] for s in speeds + first_idx_for_each_speed = [ + np.where(self.speed == s)[0][0] for s in speeds ] - last_idx_for_each_speed = sorted(last_idx_for_each_speed) + first_idx_for_each_speed = sorted(first_idx_for_each_speed) + + velocity = self.calculate_velocity() - for i, id in enumerate(last_idx_for_each_speed): + for i, id in enumerate(first_idx_for_each_speed): col = i // 2 row = i % 2 @@ -789,16 +916,147 @@ def plot_rotation_angles(self): + f" direction: {'cw' if self.direction[id] == 1 else 'ccw'}" ) + # plot velocity on top in red + ax2 = ax.twinx() + ax2.plot( + self.line_start[ + start_line_idx:end_line_idx + ], # Align x-axis with line_start + velocity[start_line_idx:end_line_idx] * -1, + color="gray", + label="velocity", + ) + + # remove top axis + ax2.spines["top"].set_visible(False) + + # set x label + ax.set_xlabel("Time (s)") + + # set y label left + ax.set_ylabel("Rotation angle (°)", color="black") + + # set y label right + ax2.set_ylabel("Velocity (°/s)", color="gray") + fig.suptitle("Rotation angles by line and frame") handles, labels = ax.get_legend_handles_labels() fig.legend(handles, labels, loc="upper right") plt.savefig(self.debug_plots_folder / "rotation_angles.png") + plt.close() + + def plot_rotation_speeds(self): + fig, ax = plt.subplots(2, 4, figsize=(15, 7)) + + unique_speeds = sorted(set(self.speed)) + unique_directions = sorted(set(self.direction)) + + velocity = self.calculate_velocity() + + # row clockwise, column speed + for i, (direction, speed) in enumerate( + itertools.product(unique_directions, unique_speeds) + ): + row = i // 4 + col = i % 4 + idx_this_speed = np.where( + np.logical_and( + self.speed == speed, self.direction == direction + ) + )[0] + + # linspace of colors depending on repetition number + colors = plt.cm.viridis(np.linspace(0, 1, len(idx_this_speed))) + + for j, idx in enumerate(idx_this_speed): + this_velocity = velocity[ + self.clock_to_latest_line_start( + self.rot_blocks_idx["start"][idx] + ) : self.clock_to_latest_line_start( + self.rot_blocks_idx["end"][idx] + ) + ] + ax[row, col].plot( + np.linspace( + 0, + len(this_velocity) * self.sampling_rate, + len(this_velocity), + ), + this_velocity, + label=f"repetition {idx}", + color=colors[j], + ) + ax[row, col].set_title(f"Speed: {speed}, direction: {direction}") + ax[row, col].spines["top"].set_visible(False) + ax[row, col].spines["right"].set_visible(False) + + # set titles of axis + ax[row, col].set_xlabel("Time (s)") + ax[row, col].set_ylabel("Velocity (°/s)") + + # leave more space between subplots + plt.subplots_adjust(hspace=0.5, wspace=0.5) + + fig.suptitle("Rotation on signal for each speed") + plt.savefig(self.debug_plots_folder / "all_speeds.png") + plt.close() ### ----------------- Derotation ----------------- ### - def plot_max_projection_with_center(self): + def find_optimal_parameters(self): + logging.info("Finding optimal parameters...") + + bo = BO_for_derotation( + self.image_stack, + self.rot_deg_line, + self.rot_deg_frame, + self.offset, + self.center_of_rotation, + self.delta, + self.init_points, + self.n_iter, + self.debug_plots_folder, + ) + + maximum = bo.optimize() + + logging.info(f"Optimal parameters: {maximum}") + logging.info(f"Target: {maximum['target']}") + + if maximum["target"] > -50: + # Consider a value of -50 as a threshold for the quality of the fit + logging.info("Using fitted center of rotation...") + x_center, y_center = maximum["params"].values() + self.center_of_rotation = (x_center, y_center) + + # write optimal center in a text file + with open( + self.debug_plots_folder / "optimal_center_of_rotation.txt", "w" + ) as f: + f.write(f"Optimal center of rotation: {x_center}, {y_center}") + + def set_optimal_center(self): + """Checks if the optimal center of rotation is calculated. + If it is not calculated, it will calculate it. + """ + try: + with open( + self.debug_plots_folder / "optimal_center_of_rotation.txt", "r" + ) as f: + optimal_center = f.read() + self.center_of_rotation = tuple( + map(float, optimal_center.split(":")[1].split(",")) + ) + logging.info("Optimal center of rotation found.") + except FileNotFoundError: + logging.info("Optimal center of rotation not found.") + self.find_optimal_parameters() + + def plot_max_projection_with_center( + self, stack, name="max_projection_with_center" + ): """Plots the maximum projection of the image stack with the center of rotation. This plot will be saved in the debug_plots folder. @@ -807,7 +1065,7 @@ def plot_max_projection_with_center(self): """ logging.info("Plotting max projection with center...") - max_projection = np.max(self.image_stack, axis=0) + max_projection = np.max(stack, axis=0) fig, ax = plt.subplots(1, 1, figsize=(5, 5)) @@ -824,7 +1082,8 @@ def plot_max_projection_with_center(self): ax.axis("off") - plt.savefig(self.debug_plots_folder / "max_projection_with_center.png") + plt.savefig(str(self.debug_plots_folder / name) + ".png") + plt.close() def derotate_frames_line_by_line(self) -> np.ndarray: """Wrapper for the function `derotate_an_image_array_line_by_line`. @@ -839,14 +1098,15 @@ def derotate_frames_line_by_line(self) -> np.ndarray: logging.info("Starting derotation by line...") if self.debugging_plots: - self.plot_max_projection_with_center() - - offset = self.find_image_offset(self.image_stack[0]) + self.plot_max_projection_with_center(self.image_stack) + # By default rotation_plane_angle and rotation_plane_orientation are 0 + # they have to be overwritten before calling the function. + # To calculate them please use the ellipse fit. rotated_image_stack = derotate_an_image_array_line_by_line( self.image_stack, self.rot_deg_line, - blank_pixels_value=offset, + blank_pixels_value=self.offset, center=self.center_of_rotation, plotting_hook_line_addition=self.hooks.get( "plotting_hook_line_addition" @@ -854,8 +1114,18 @@ def derotate_frames_line_by_line(self) -> np.ndarray: plotting_hook_image_completed=self.hooks.get( "plotting_hook_image_completed" ), + use_homography=self.rotation_plane_angle != 0, + rotation_plane_angle=self.rotation_plane_angle, + rotation_plane_orientation=self.rotation_plane_orientation, ) + if self.debugging_plots: + self.plot_max_projection_with_center( + rotated_image_stack, + name="derotated_max_projection_with_center", + ) + self.mean_image_for_each_rotation(rotated_image_stack) + logging.info("✨ Image stack rotated ✨") return rotated_image_stack @@ -892,6 +1162,23 @@ def find_image_offset(img): offset = np.min(gm.means_) return offset + def mean_image_for_each_rotation(self, rotated_image_stack): + folder = self.debug_plots_folder / "mean_images" + Path(folder).mkdir(parents=True, exist_ok=True) + for i, (start, end) in enumerate( + zip(self.rot_blocks_idx["start"], self.rot_blocks_idx["end"]) + ): + frame_start = self.clock_to_latest_frame_start(start) + frame_end = self.clock_to_latest_frame_start(end) + mean_image = np.mean( + rotated_image_stack[frame_start:frame_end], axis=0 + ) + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + ax.imshow(mean_image, cmap="viridis") + ax.axis("off") + plt.savefig(str(folder / f"mean_image_rotation_{i}.png")) + plt.close() + ### ----------------- Saving ----------------- ### @staticmethod def add_circle_mask( @@ -971,14 +1258,11 @@ def save(self, masked: np.ndarray): masked : np.ndarray The masked derotated image stack. """ - path = self.config["paths_write"]["derotated_tiff_folder"] - Path(path).mkdir(parents=True, exist_ok=True) - - imsave( - path + self.config["paths_write"]["saving_name"] + ".tif", + imwrite( + str(self.file_saving_path_with_name) + ".tif", np.array(masked), ) - logging.info(f"Masked image saved in {path}") + logging.info(f"Saving {str(self.file_saving_path_with_name) + '.tif'}") def save_csv_with_derotation_data(self): """Saves a csv file with the rotation angles by line and frame, @@ -994,8 +1278,27 @@ def save_csv_with_derotation_data(self): ) df["frame"] = np.arange(self.num_frames) - df["rotation_angle"] = self.rot_deg_frame[: self.num_frames] - df["clock"] = self.frame_start[: self.num_frames] + if len(self.rot_deg_frame) > self.num_frames: + df["rotation_angle"] = self.rot_deg_frame[: self.num_frames] + df["clock"] = self.frame_start[: self.num_frames] + logging.warning( + "Number of rotation angles by frame is higher than the" + " number of frames" + ) + elif len(self.rot_deg_frame) < self.num_frames: + missing_frames = self.num_frames - len(self.rot_deg_frame) + df["rotation_angle"] = np.append( + self.rot_deg_frame, [0] * missing_frames + ) + df["clock"] = np.append(self.frame_start, [0] * missing_frames) + + logging.warning( + "Number of rotation angles by frame is lower than the" + " number of frames. Adjusted." + ) + else: + df["rotation_angle"] = self.rot_deg_frame + df["clock"] = self.frame_start df["direction"] = np.nan * np.ones(len(df)) df["speed"] = np.nan * np.ones(len(df)) @@ -1017,13 +1320,7 @@ def save_csv_with_derotation_data(self): rotation_counter += 1 adding_roatation = False - Path(self.config["paths_write"]["derotated_tiff_folder"]).mkdir( - parents=True, exist_ok=True - ) - df.to_csv( - self.config["paths_write"]["derotated_tiff_folder"] - + self.config["paths_write"]["saving_name"] - + ".csv", + str(self.file_saving_path_with_name) + ".csv", index=False, ) diff --git a/derotation/analysis/incremental_derotation_pipeline.py b/derotation/analysis/incremental_derotation_pipeline.py index fff72456..9c7be330 100644 --- a/derotation/analysis/incremental_derotation_pipeline.py +++ b/derotation/analysis/incremental_derotation_pipeline.py @@ -1,4 +1,3 @@ -import copy import logging from pathlib import Path from typing import Dict, Tuple @@ -15,6 +14,7 @@ plot_ellipse_fit_and_centers, ) from derotation.analysis.full_derotation_pipeline import FullPipeline +from derotation.analysis.mean_images import calculate_mean_images class IncrementalPipeline(FullPipeline): @@ -46,7 +46,9 @@ def __call__(self): derotated_images = self.deroatate_by_frame() masked_unregistered = self.add_circle_mask(derotated_images) - mean_images = self.calculate_mean_images(masked_unregistered) + mean_images = calculate_mean_images( + masked_unregistered, self.rot_deg_frame + ) target_image = self.get_target_image(masked_unregistered) shifts = self.get_shifts_using_phase_cross_correlation( mean_images, target_image @@ -221,7 +223,7 @@ def check_rotation_number(self, start: np.ndarray, end: np.ndarray): if start.shape[0] != 1: raise ValueError("Number of rotations is not as expected") - def plot_rotation_angles(self): + def plot_rotation_angles_and_velocity(self): """Plots example rotation angles by line and frame for each speed. This plot will be saved in the debug_plots folder. Please inspect it to check that the rotation angles are correctly @@ -255,38 +257,6 @@ def plot_rotation_angles(self): / "rotation_angles.png" ) - def calculate_mean_images(self, image_stack: np.ndarray) -> list: - """Calculate the mean images for each rotation angle. This required - to calculate the shifts using phase cross correlation. - - Parameters - ---------- - rotated_image_stack : np.ndarray - The rotated image stack. - - Returns - ------- - list - The list of mean images. - """ - logging.info("Calculating mean images...") - - # correct for a mismatch in the total number of frames - # and the number of angles, given by instrument error - angles_subset = copy.deepcopy(self.rot_deg_frame[2:]) - # also there is a bias on the angles - angles_subset += -0.1 - rounded_angles = np.round(angles_subset, 2) - - mean_images = [] - for i in np.arange(10, 360, 10): - images = image_stack[rounded_angles == i] - mean_image = np.mean(images, axis=0) - - mean_images.append(mean_image) - - return mean_images - def save_csv_with_derotation_data(self): """Saves a csv file with the rotation angles by line and frame, and the rotation on signal. @@ -483,7 +453,9 @@ def find_center_of_rotation(self) -> Tuple[int, int]: "Fitting an ellipse to the largest blob centers " + "to find the center of rotation..." ) - mean_images = self.calculate_mean_images(self.image_stack) + mean_images = calculate_mean_images( + self.image_stack, self.rot_deg_frame + ) logging.info("Finding blobs...") bd = BlobDetection(self.debugging_plots, self.debug_plots_folder) @@ -494,7 +466,8 @@ def find_center_of_rotation(self) -> Tuple[int, int]: # Fit an ellipse to the largest blob centers and get its center center_x, center_y, a, b, theta = fit_ellipse_to_points( - coord_first_blob_of_every_image + coord_first_blob_of_every_image, + pixels_in_row=self.num_lines_per_frame, ) if self.debugging_plots: diff --git a/derotation/analysis/mean_images.py b/derotation/analysis/mean_images.py new file mode 100644 index 00000000..30895a48 --- /dev/null +++ b/derotation/analysis/mean_images.py @@ -0,0 +1,50 @@ +import copy +import logging + +import numpy as np + + +def calculate_mean_images( + image_stack: np.ndarray, rot_deg_frame: np.ndarray, round_decimals: int = 2 +) -> np.ndarray: + """Calculate the mean images for each rotation angle. This required + to calculate the shifts using phase cross correlation. + + Parameters + ---------- + rotated_image_stack : np.ndarray + The rotated image stack. + + Returns + ------- + np.ndarray + The mean images for each rotation angle. + """ + # correct for a mismatch in the total number of frames + angles_subset = copy.deepcopy(rot_deg_frame) + if len(angles_subset) > len(image_stack): + angles_subset = angles_subset[: len(image_stack)] + else: + image_stack = image_stack[: len(angles_subset)] + + assert len(image_stack) == len(angles_subset), ( + "Mismatch in the number of images and angles" + ) + + rounded_angles = np.round(angles_subset, round_decimals) + + mean_images = [] + for i in np.arange(10, 360, 10): + try: + images = image_stack[rounded_angles == i] + mean_image = np.mean(images, axis=0) + + mean_images.append(mean_image) + except IndexError as e: + logging.warning(f"Angle {i} not found in the image stack") + logging.warning(e) + + example_angles = np.random.choice(rounded_angles, 100) + logging.info(f"Example angles: {example_angles}") + + return np.asarray(mean_images) diff --git a/derotation/analysis/metrics.py b/derotation/analysis/metrics.py new file mode 100644 index 00000000..c0f22d61 --- /dev/null +++ b/derotation/analysis/metrics.py @@ -0,0 +1,132 @@ +from collections import Counter +from pathlib import Path +from typing import List + +import matplotlib.pyplot as plt +import numpy as np +from skimage.feature import blob_log +from sklearn.cluster import DBSCAN + + +def ptd_of_most_detected_blob( + mean_images_by_angle: np.ndarray, + plot: bool = True, + blob_log_kwargs: dict = { + "min_sigma": 7, + "max_sigma": 10, + "threshold": 0.95, + "overlap": 0, + }, + debug_plots_folder: Path = Path("/debug_plots"), + image_names: List[str] = [ + "detected_blobs.png", + "most_detected_blob_centers.png", + ], + DBSCAN_max_distance: float = 10.0, + clipping_percentiles: List[float] = [99.0, 99.99], +) -> float: + """Calculate the peak to peak distance of the centers of the most + detected blob in the derotated stack across all frames. + + Parameters + ---------- + mean_images_by_angle : np.ndarray + The derotated stack of images. + plot : bool, optional + Whether to plot the detected blobs, by default True + blob_log_kwargs : _type_, optional + The parameters for the blob detection algorithm, by default + { "min_sigma": 7, "max_sigma": 10, "threshold": 0.95, "overlap": 0, } + debug_plots_folder : str, optional + The folder to save the debugging plots, by default "/debug_plots" + image_names : List[str], optional + The names of the images to save if plot is True, by default + ["detected_blobs.png", "most_detected_blob_centers.png"] + DBSCAN_max_distance : float, optional + The maximum distance between two samples for one to be considered as + in the neighborhood of the other, by default 10.0 + clipping_percentiles : List[float], optional + The percentiles to clip the images to, by default [99.0, 99.99] + Returns + ------- + float + The peak to peak distance of the centers of the most detected blob. + """ + # clip all the images to the same contrast + clipped_images = [ + np.clip( + img, + np.percentile(img, clipping_percentiles[0]), + np.percentile(img, clipping_percentiles[1]), + ) + for img in mean_images_by_angle + ] + + # Detect the blobs in the derotated stack in each frame + # blobs is a list(list(x, y, sigma)) of the detected blobs for every frame + blobs = [ + blob_log( + img, + min_sigma=blob_log_kwargs["min_sigma"], + max_sigma=blob_log_kwargs["max_sigma"], + threshold=blob_log_kwargs["threshold"], + overlap=blob_log_kwargs["overlap"], + ) + for img in clipped_images + ] + + # plot image with center of blobs + if plot: + fig, ax = plt.subplots() + ax.imshow(clipped_images[0], cmap="gray") + for blob in blobs[0]: + y, x, r = blob + c = plt.Circle((x, y), r, color="red", linewidth=2, fill=False) + plt.gca().add_artist(c) + + # save + plt.savefig(debug_plots_folder / image_names[0]) + plt.close() + + # Flatten the blob list and add frame indices + _blobs = [] + for frame_idx, frame_blobs in enumerate(blobs): + for blob in frame_blobs: + _blobs.append([*blob, frame_idx]) + all_blobs = np.array(_blobs) + + # Use DBSCAN to cluster blobs based on proximity + + coords = all_blobs[:, :3] # x, y, radius + DBSCAN_max_distance = float(DBSCAN_max_distance) + clustering = DBSCAN( + eps=DBSCAN_max_distance, + min_samples=2, + ).fit(coords) + all_blobs = np.column_stack( + (all_blobs, clustering.labels_) + ) # Add cluster labels + + cluster_counts = Counter(all_blobs[:, -1]) # Cluster labels + most_detected_label = max(cluster_counts, key=lambda k: cluster_counts[k]) + + # Extract blobs belonging to the most detected cluster + most_detected_blobs = all_blobs[all_blobs[:, -1] == most_detected_label] + + # Calculate range (peak to peak) + ptp = np.ptp(most_detected_blobs[:, 0]) + np.ptp(most_detected_blobs[:, 1]) + + # plot the most detected blobs centers + if plot: + fig, ax = plt.subplots() + ax.imshow(clipped_images[0], cmap="gray") + for blob in most_detected_blobs: + y, x, *_ = blob + # plot an x on the center + plt.scatter(x, y, color="red", marker="x") + + # save + plt.savefig(debug_plots_folder / image_names[1]) + plt.close() + + return ptp diff --git a/derotation/config/full_rotation.yml b/derotation/config/full_rotation.yml index bac1bbb7..27823a8d 100644 --- a/derotation/config/full_rotation.yml +++ b/derotation/config/full_rotation.yml @@ -18,12 +18,14 @@ channel_names: [ "PI_rotticks", ] - rotation_increment: 0.2 +adjust_increment: True rot_deg: 360 debugging_plots: True +frame_rate: 6.74 + analog_signals_processing: find_rotation_ticks_peaks: height: 4 @@ -35,3 +37,8 @@ analog_signals_processing: interpolation: line_use_start: True frame_use_start: True + +biased_center: [129, 121] +delta_center: 7 +init_points: 2 +n_iter: 10 diff --git a/derotation/config/incremental_rotation.yml b/derotation/config/incremental_rotation.yml index 69e0c0d5..6cf327cf 100644 --- a/derotation/config/incremental_rotation.yml +++ b/derotation/config/incremental_rotation.yml @@ -24,6 +24,8 @@ rot_deg: 360 debugging_plots: True +frame_rate: 6.74 + analog_signals_processing: find_rotation_ticks_peaks: height: 4 diff --git a/derotation/derotate_batch.py b/derotation/derotate_batch.py new file mode 100644 index 00000000..6c0dae01 --- /dev/null +++ b/derotation/derotate_batch.py @@ -0,0 +1,75 @@ +import logging +import traceback +from pathlib import Path + +import yaml + +from derotation.analysis.full_derotation_pipeline import FullPipeline + + +def update_config_paths( + config, tif_path, bin_path, dataset_path, output_folder, kind="full" +): + # Set config paths based on provided arguments + config["paths_read"]["path_to_randperm"] = str( + Path(dataset_path).parent / "stimlus_randperm.mat" + ) + config["paths_read"]["path_to_aux"] = str(bin_path) + config["paths_read"]["path_to_tif"] = str(tif_path) + + # Set output paths to the specified output_folder + config["paths_write"]["debug_plots_folder"] = str( + Path(output_folder) / "derotation" / f"debug_plots_{kind}" + ) + config["paths_write"]["logs_folder"] = str( + Path(output_folder) / "derotation" / "logs" + ) + config["paths_write"]["derotated_tiff_folder"] = str( + Path(output_folder) / "derotation/" + ) + config["paths_write"]["saving_name"] = f"derotated_{kind}" + + return config + + +def derotate(dataset_folder: Path, output_folder): + this_module_path = Path(__file__).parent + + # FULL DEROTATION PIPELINE + # find tif and bin files + bin_path = list(dataset_folder.rglob("*rotation_*001.bin"))[0] + tif_path = list(dataset_folder.rglob("rotation_00001.tif"))[0] + + # Load the config template and update paths + + config_template_path = this_module_path / Path("config/full_rotation.yml") + with open(config_template_path, "r") as f: + config = yaml.safe_load(f) + + config = update_config_paths( + config, tif_path, bin_path, dataset_folder, output_folder, kind="full" + ) + + # Create output directories if they don't exist + Path(config["paths_write"]["debug_plots_folder"]).mkdir( + parents=True, exist_ok=True + ) + Path(config["paths_write"]["logs_folder"]).mkdir( + parents=True, exist_ok=True + ) + Path(config["paths_write"]["derotated_tiff_folder"]).mkdir( + parents=True, exist_ok=True + ) + + logging.info("Running full derotation pipeline") + + # Run the pipeline + try: + derotator = FullPipeline(config) + derotator() + return derotator.metric + except Exception as e: + logging.error("Full derotation pipeline failed") + logging.error(e.args) + logging.error(traceback.format_exc()) + raise e diff --git a/derotation/derotate_by_line.py b/derotation/derotate_by_line.py index b1afab16..2da9a170 100644 --- a/derotation/derotate_by_line.py +++ b/derotation/derotate_by_line.py @@ -306,9 +306,9 @@ def apply_homography( ) # check shape - assert ( - new_image_stack.shape == image_stack.shape - ), f"Shape mismatch: {new_image_stack.shape} != {image_stack.shape}" + assert new_image_stack.shape == image_stack.shape, ( + f"Shape mismatch: {new_image_stack.shape} != {image_stack.shape}" + ) image_stack = new_image_stack return image_stack diff --git a/derotation/simulate/synthetic_data.py b/derotation/simulate/synthetic_data.py index bdc5678f..a9d17882 100644 --- a/derotation/simulate/synthetic_data.py +++ b/derotation/simulate/synthetic_data.py @@ -376,15 +376,21 @@ def __init__(self): :: rotated_stack_incremental.shape[1] ][: rotated_stack_incremental.shape[0]] self.num_frames = rotated_stack_incremental.shape[0] + self.num_lines_per_frame = rotated_stack_incremental.shape[1] self.debugging_plots = make_plots self.debug_plots_folder = Path("debug/") - def calculate_mean_images(self, image_stack: np.ndarray) -> list: - # Overwrite original method as it is too bound + @staticmethod + def calculate_mean_images( + image_stack: np.ndarray, + rot_deg_frame: np.ndarray, + round_decimals: int = 0, + ) -> list: + # Override original method as it is too bound # to signal coming from a real motor - angles_subset = copy.deepcopy(self.rot_deg_frame) - rounded_angles = np.round(angles_subset) + angles_subset = copy.deepcopy(rot_deg_frame) + rounded_angles = np.round(angles_subset, round_decimals) mean_images = [] for i in np.arange(10, 360, 10): diff --git a/examples/derotation_slurm_job.py b/examples/derotation_slurm_job.py deleted file mode 100644 index 25c485c2..00000000 --- a/examples/derotation_slurm_job.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import sys -from pathlib import Path - -import yaml -from full_derotation_pipeline import FullPipeline -from incremental_derotation_pipeline import ( - IncrementalPipeline, -) - -# ===================================================================== -# Set up the config files -# ===================================================================== - -job_id = int(sys.argv[1:][0]) -dataset_path = sys.argv[1:][1] -datasets = [path for path in os.listdir(dataset_path) if path.startswith("23")] -dataset = datasets[job_id] - -bin_files = [ - file - for file in os.listdir(f"{dataset_path}/{dataset}/aux_stim/") - if file.endswith(".bin") -] -full_rotation_bin = [file for file in bin_files if "_rotation" in file][0] -incremental_bin = [file for file in bin_files if "increment" in file][0] - -image_files = [ - file - for file in os.listdir(f"{dataset_path}/{dataset}/imaging/") - if file.endswith(".tif") -] -full_rotation_image = [file for file in image_files if "rotation_0" in file][0] -incremental_image = [file for file in image_files if "increment_0" in file][0] - -Path(f"{dataset_path}/{dataset}/debug_plots_incremental/").mkdir( - parents=True, exist_ok=True -) -Path(f"{dataset_path}/{dataset}/debug_plots_full/").mkdir( - parents=True, exist_ok=True -) -Path(f"{dataset_path}/{dataset}/logs/").mkdir(parents=True, exist_ok=True) -Path(f"{dataset_path}/{dataset}/derotated/").mkdir(parents=True, exist_ok=True) - -for config_name in ["incremental_rotation", "full_rotation"]: - with open(f"derotation/config/{config_name}.yml") as f: - config = yaml.load(f, Loader=yaml.FullLoader) - - config["paths_read"]["path_to_randperm"] = ( - f"{dataset_path}/stimlus_randperm.mat" - ) - bin_name = ( - incremental_bin - if config_name == "incremental_rotation" - else full_rotation_bin - ) - config["paths_read"]["path_to_aux"] = ( - f"{dataset_path}/{dataset}/aux_stim/{bin_name}" - ) - image_name = ( - incremental_image - if config_name == "incremental_rotation" - else full_rotation_image - ) - config["paths_read"]["path_to_tif"] = ( - f"{dataset_path}/{dataset}/imaging/{image_name}" - ) - config["paths_write"]["debug_plots_folder"] = ( - f"{dataset_path}/{dataset}/debug_plots_{config_name.split('_')[0]}" - ) - config["paths_write"]["logs_folder"] = f"{dataset_path}/{dataset}/logs/" - config["paths_write"]["derotated_tiff_folder"] = ( - f"{dataset_path}/{dataset}/derotated/" - ) - config["paths_write"]["saving_name"] = ( - f"derotated_image_stack_{config_name.split('_')[0]}" - ) - - with open(f"derotation/config/{config_name}_{job_id}.yml", "w") as f: - yaml.dump(config, f) - - -# ===================================================================== -# Run the pipeline -# ===================================================================== - -derotate_incremental = IncrementalPipeline(f"incremental_rotation_{job_id}") -derotate_incremental() - -derotate_full = FullPipeline(f"full_rotation_{job_id}") -derotate_full.mask_diameter = derotate_incremental.new_diameter -derotate_full() diff --git a/pyproject.toml b/pyproject.toml index 226d550f..489734ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "tqdm", "scikit-learn", "scikit-image", + "bayesian-optimization", ] [project.urls] @@ -128,4 +129,5 @@ commands = [tool.codespell] skip = '.git' +ignore-words-list = "ptd" check-hidden = true diff --git a/tests/test_integration/test_derotation_with_simulated_data.py b/tests/test_integration/test_derotation_with_simulated_data.py index 29ca5e1b..9d485dd4 100644 --- a/tests/test_integration/test_derotation_with_simulated_data.py +++ b/tests/test_integration/test_derotation_with_simulated_data.py @@ -182,4 +182,4 @@ def assert_blob_detection( if __name__ == "__main__": Path("debug/").mkdir(parents=True, exist_ok=True) - test_derotation_with_rotation_out_of_plane((0, 0), 20, 45, plots=True) + test_derotation_with_rotation_out_of_plane((0, 0), 0, 0, plots=True) diff --git a/tests/test_integration/test_rotation_out_of_plane.py b/tests/test_integration/test_rotation_out_of_plane.py index 98373770..2ba070e4 100644 --- a/tests/test_integration/test_rotation_out_of_plane.py +++ b/tests/test_integration/test_rotation_out_of_plane.py @@ -114,7 +114,9 @@ def test_max_projection( # invert centers to (x, y) format centers = np.array([(x, y) for y, x in centers]) - xc, yc, a, b, orientation = fit_ellipse_to_points(centers) + xc, yc, a, b, orientation = fit_ellipse_to_points( + centers, pixels_in_row=rotated_image_stack.shape[1] + ) if exp_orientation is None and plane_angle == 0: # lower tolerance for the major and minor axes @@ -130,9 +132,9 @@ def test_max_projection( debug_plots_folder=Path("debug/"), saving_name=f"ellipse_fit_{plane_angle}_{exp_orientation}.png", ) - assert ( - False - ), f"Major and minor axes should be close, instead got {a} and {b}" + assert False, ( + f"Major and minor axes should be close,instead got {a} and {b}" + ) elif exp_orientation is not None: # Major axis orientation in clockwise direction as radians. exp_orientation = np.deg2rad(exp_orientation) diff --git a/tests/test_regression/test_derotation_by_line.py b/tests/test_regression/test_derotation_by_line.py index 25ed8990..c57402e9 100644 --- a/tests/test_regression/test_derotation_by_line.py +++ b/tests/test_regression/test_derotation_by_line.py @@ -48,6 +48,9 @@ def test_derotation_by_line(n_lines, n_total_lines, len_stack, image_stack): pipeline.center_of_rotation = (n_lines // 2, n_lines // 2) pipeline.hooks = {} pipeline.debugging_plots = False + pipeline.offset = pipeline.find_image_offset(pipeline.image_stack[0]) + pipeline.rotation_plane_angle = 0 + pipeline.rotation_plane_orientation = 0 derotated_images = pipeline.derotate_frames_line_by_line() diff --git a/tests/test_unit/test_adjust_rotation_increment.py b/tests/test_unit/test_adjust_rotation_increment.py index 47d3a516..1d3ead3b 100644 --- a/tests/test_unit/test_adjust_rotation_increment.py +++ b/tests/test_unit/test_adjust_rotation_increment.py @@ -16,9 +16,9 @@ def test_adjust_rotation_increment_360( new_increments = np.round(new_increments, 0) assert np.all(new_increments == corrected_increments), f"{new_increments}" - assert np.all( - new_ticks_per_rotation == ticks_per_rotation_calculated - ), f"{new_ticks_per_rotation}" + assert np.all(new_ticks_per_rotation == ticks_per_rotation_calculated), ( + f"{new_ticks_per_rotation}" + ) def test_adjust_rotation_increment_5( @@ -37,6 +37,6 @@ def test_adjust_rotation_increment_5( [0.139, 0.119, 0.152, 0.128, 0.147, 0.179, 0.2, 0.167, 0.139, 0.152] ) - assert np.all( - new_increments == correct_increments - ), f"new_increments: {new_increments}" + assert np.all(new_increments == correct_increments), ( + f"new_increments: {new_increments}" + )