diff --git a/hexrdgui/calibration/polarview.py b/hexrdgui/calibration/polarview.py index 60265a529..c0a4cbb18 100644 --- a/hexrdgui/calibration/polarview.py +++ b/hexrdgui/calibration/polarview.py @@ -55,9 +55,6 @@ def __init__(self, instrument, distortion_instrument=None): # Use an image dict with the panel buffers applied. # This keeps invalid pixels from bleeding out in the polar view self.images_dict = HexrdConfig().images_dict - # 0 is a better fill value because it results in fewer nans in - # the final image. - HexrdConfig().apply_panel_buffer_to_images(self.images_dict, 0) self.warp_dict = {} @@ -176,8 +173,16 @@ def images_dict(self): @images_dict.setter def images_dict(self, v): + # This images_dict sometimes gets modified by external callers, + # such as when a waterfall plot is created. So we need to make + # sure that everything that needs to be updated gets updated + # here. self._images_dict = v + # 0 is a better fill value because it results in fewer nans in + # the final image. + HexrdConfig().apply_panel_buffer_to_images(self._images_dict, 0) + # Cache the image min and max for later use self.min = min(x.min() for x in v.values()) self.max = max(x.max() for x in v.values()) @@ -240,13 +245,12 @@ def detector_borders(self, det): @property def all_detector_borders(self): borders = {} - for key in self.images_dict.keys(): + for key in self.detectors: borders[key] = self.detector_borders(key) return borders def create_warp_image(self, det): - # lcount = 0 img = self.images_dict[det] panel = self.detectors[det] @@ -508,7 +512,7 @@ def warp_all_images(self): self.reset_cached_distortion_fields() # Create the warped image for each detector - for det in self.images_dict.keys(): + for det in self.detectors: self.create_warp_image(det) # Generate the final image @@ -540,6 +544,9 @@ def update_detectors(self, detectors): self.generate_image() def reset_cached_distortion_fields(self): + # These are only reset so that other parts of the code + # will not use them while we are generating new ones. + # They are actually still cached elsewhere. HexrdConfig().polar_corr_field_polar = None HexrdConfig().polar_angular_grid = None diff --git a/hexrdgui/image_canvas.py b/hexrdgui/image_canvas.py index d1de09d53..83dce9bc2 100644 --- a/hexrdgui/image_canvas.py +++ b/hexrdgui/image_canvas.py @@ -4,8 +4,9 @@ import sys from PySide6.QtCore import QThreadPool, QTimer, Signal, Qt -from PySide6.QtWidgets import QFileDialog, QMessageBox +from PySide6.QtWidgets import QFileDialog, QMessageBox, QProgressDialog +from matplotlib.axes import Axes from matplotlib.backends.backend_qtagg import FigureCanvas from matplotlib.figure import Figure from matplotlib.lines import Line2D @@ -20,6 +21,7 @@ from hexrd import distortion as distortion_pkg +from hexrdgui import utils from hexrdgui.async_worker import AsyncWorker from hexrdgui.blit_manager import BlitManager from hexrdgui.calibration.cartesian_plot import cartesian_viewer @@ -33,12 +35,12 @@ from hexrdgui.masking.create_polar_mask import create_polar_line_data_from_raw from hexrdgui.masking.mask_manager import MaskManager from hexrdgui.snip_viewer_dialog import SnipViewerDialog -from hexrdgui import utils from hexrdgui.utils.array import split_array from hexrdgui.utils.conversions import ( angles_to_stereo, cart_to_angles, cart_to_pixels, q_to_tth, tth_to_q, ) from hexrdgui.utils.tth_distortion import apply_tth_distortion_if_needed +from hexrdgui.waterfall_plot import WaterfallPlotDialog # Increase these font sizes (compared to the global font) by the specified # amounts. @@ -52,6 +54,8 @@ class ImageCanvas(FigureCanvas): norm_modified = Signal() transform_modified = Signal() + _update_waterfall_plot_progress = Signal() + def __init__(self, parent=None, image_names=None): self.figure = Figure(tight_layout=True) super().__init__(self.figure) @@ -76,6 +80,7 @@ def __init__(self, parent=None, image_names=None): self.raw_view_images_dict = {} self._mask_boundary_artists = [] self._latest_compute_view_worker = None + self._waterfall_plot_dialog = None # Track the current mode so that we can more lazily clear on change. self.mode = None @@ -123,6 +128,14 @@ def setup_connections(self): HexrdConfig().panel_distortion_modified.connect( self.on_panel_distortion_changed) + # This *must* be a queued connection, because Mac requires the + # progress to be updated on the GUI thread. Otherwise, it will + # crash the application on Mac. + self._update_waterfall_plot_progress.connect( + self._update_waterfall_plot_progress_slot, + Qt.QueuedConnection, + ) + @property def thread_pool(self): return QThreadPool.globalInstance() @@ -1157,55 +1170,10 @@ def finish_show_polar(self, iviewer): HexrdConfig().last_unscaled_azimuthal_integral_data = unscaled self.azimuthal_integral_axis = axis - axis.set_ylabel(r'Azimuthal Average', **self.label_kwargs) self.update_azimuthal_plot_overlays() self.update_wppf_plot() - # Set up formatting for the x-axis - default_formatter = axis.xaxis.get_major_formatter() - f = self.format_polar_x_major_ticks - formatter = PolarXAxisFormatter(default_formatter, f) - axis.xaxis.set_major_formatter(formatter) - - axis.yaxis.set_major_locator(AutoLocator()) - axis.yaxis.set_minor_locator(AutoMinorLocator()) - - axis.xaxis.set_major_locator(PolarXAxisTickLocator(self)) - self.axis.xaxis.set_minor_locator( - PolarXAxisMinorTickLocator(self) - ) - - # change property of ticks - axis.tick_params(**self.major_tick_kwargs) - axis.tick_params(**self.minor_tick_kwargs) - - # add grid lines parallel to x-axis in azimuthal average - kwargs = { - 'visible': True, - 'which': 'major', - 'axis': 'y', - 'linewidth': 0.25, - 'linestyle': '-', - 'color': 'k', - 'alpha': 0.75, - } - axis.grid(**kwargs) - - kwargs = { - 'visible': True, - 'which': 'minor', - 'axis': 'y', - 'linewidth': 0.075, - 'linestyle': '--', - 'color': 'k', - 'alpha': 0.9, - } - axis.grid(**kwargs) - - # add grid lines parallel to y-axis - kwargs['which'] = 'both' - kwargs['axis'] = 'x' - axis.grid(**kwargs) + self._setup_azimuthal_axis(axis) else: self.update_azimuthal_integral_plot() axis = self.azimuthal_integral_axis @@ -1331,6 +1299,65 @@ def on_beam_energy_modified(self): # Update the beam energy on the instrument self.iviewer.instr.beam_energy = HexrdConfig().beam_energy + def _setup_azimuthal_axis(self, axis: Axes): + # Set the labels + axis.set_xlabel(self.polar_xlabel, **self.label_kwargs) + axis.set_ylabel(r'Azimuthal Average', **self.label_kwargs) + + # Set up formatting for the x-axis + # This is important in case "Q" is on the x axis instead + # of two theta. + default_formatter = axis.xaxis.get_major_formatter() + f = self.format_polar_x_major_ticks + formatter = PolarXAxisFormatter(default_formatter, f) + axis.xaxis.set_major_formatter(formatter) + + axis.yaxis.set_major_locator(AutoLocator()) + axis.yaxis.set_minor_locator(AutoMinorLocator()) + + axis.xaxis.set_major_locator(PolarXAxisTickLocator(self)) + self.axis.xaxis.set_minor_locator( + PolarXAxisMinorTickLocator(self) + ) + + # change property of ticks + axis.tick_params(**self.major_tick_kwargs) + axis.tick_params(**self.minor_tick_kwargs) + + # Set up the grids + # These are default kwargs for the grids. + default_kwargs = { + 'visible': True, + 'linewidth': 0.075, + 'linestyle': '--', + 'color': 'k', + 'alpha': 0.9, + } + + # Grid for minor y tickers + axis.grid(**{ + **default_kwargs, + 'which': 'minor', + 'axis': 'y', + 'linewidth': 0.25, + 'linestyle': '-', + 'alpha': 0.75, + }) + + # Grid for major y tickers + axis.grid(**{ + **default_kwargs, + 'which': 'major', + 'axis': 'y', + }) + + # Grid for all x tickers + axis.grid(**{ + **default_kwargs, + 'which': 'both', + 'axis': 'x', + }) + @property def polar_x_axis_type(self): return HexrdConfig().polar_x_axis_type @@ -1510,10 +1537,14 @@ def compute_azimuthal_integral_sum(self, scaled=True): pimg = self.scaled_images[0] else: pimg = self.unscaled_images[0] + + return self._compute_azimuthal_integral_sum(pimg) + + def _compute_azimuthal_integral_sum(self, pimg: np.ndarray) -> np.ndarray: # !!! NOTE: visible polar masks have already been applied # in polarview.py - masked = np.ma.masked_array(pimg, mask=np.isnan(pimg)) offset = HexrdConfig().azimuthal_offset + masked = np.ma.masked_array(pimg, mask=np.isnan(pimg)) return masked.sum(axis=0) / np.sum(~masked.mask, axis=0) + offset def clear_azimuthal_overlay_artists(self): @@ -1767,6 +1798,133 @@ def export_current_plot(self, filename): self.iviewer.write_image(filename) + def create_waterfall_plot(self): + if self.mode != ViewType.polar: + msg = 'Cannot create waterfall plot if we are not in polar mode' + raise Exception(msg) + + if not self.iviewer: + msg = 'Cannot create waterfall plot without an iviewer' + raise Exception(msg) + + if self._waterfall_plot_dialog is not None: + self._waterfall_plot_dialog.hide() + self._waterfall_plot_dialog = None + + # Determine the number of lineouts + num_lineouts = HexrdConfig().imageseries_length + + # Display a progress dialog indicating that we are + # generating intensities... + progress = QProgressDialog( + 'Generating azimuthal lineouts...', + None, + 0, + num_lineouts, + self, + ) + progress.setWindowTitle('HEXRD') + progress.setValue(1) + + # No close button in the corner + flags = progress.windowFlags() + progress.setWindowFlags( + (flags | Qt.CustomizeWindowHint) & + ~Qt.WindowCloseButtonHint + ) + + self._create_waterfall_progress = progress + + # Compute azimuthal lineouts in a background thread + worker = AsyncWorker(self._create_waterfall_lineouts) + self.thread_pool.start(worker) + self._latest_compute_view_worker = worker + + def on_finished(): + progress.reject() + + # Get the results and close the progress dialog when finished + worker.signals.result.connect(self._finish_create_waterfall) + worker.signals.finished.connect(on_finished) + + progress.exec() + + def _update_waterfall_plot_progress_slot(self): + progress = self._create_waterfall_progress + if progress is None: + return + + progress.setValue(progress.value() + 1) + + def _create_waterfall_lineouts(self) -> list[np.ndarray]: + # Determine the number of lineouts + num_lineouts = HexrdConfig().imageseries_length + lineouts = [None] * num_lineouts + + # We can already compute the lineout for the current frame + current_idx = HexrdConfig().current_imageseries_idx + lineouts[current_idx] = self.compute_azimuthal_integral_sum() + + # Make a deep copy of the iviewer, since we will modify it + iviewer = copy.deepcopy(self.iviewer) + + # Now generate the lineouts for the other frames + for i in range(num_lineouts): + if i == current_idx: + # We already generated this one + continue + + # Create the new imageseries dict + HexrdConfig().current_imageseries_idx = i + try: + new_images_dict = HexrdConfig().images_dict + finally: + # Always restore the previous index + HexrdConfig().current_imageseries_idx = current_idx + + # Now force the image dict to change + iviewer.pv.images_dict = new_images_dict + + # Generate the new image + iviewer.pv.warp_all_images() + + # Grab the new image + polar_img = iviewer.img + if HexrdConfig().polar_apply_scaling_to_lineout: + # Apply the transform + polar_img = self.transform(polar_img) + + # Compute the integration + lineouts[i] = self._compute_azimuthal_integral_sum(polar_img) + + # The progress must be updated in the GUI thread. Otherwise, + # it will crash on Mac. + self._update_waterfall_plot_progress.emit() + + return lineouts + + def _finish_create_waterfall(self, lineouts: list[np.ndarray]): + # Now create the waterfall plot dialog with the lineouts + # Create a matplotlib figure and set up everything + figure = plt.figure() + ax = figure.add_subplot() + + # Grab tth + angular_grid = self.iviewer.angular_grid + tth = np.degrees(angular_grid[1][0]) + line_data = [(tth, lineout.filled(np.nan)) for lineout in lineouts] + + # Set up the same azimuthal axes parameters as the polar view + self._setup_azimuthal_axis(ax) + + # Disable the tick labels + ax.set_yticklabels([]) + + # Now create and show the waterfall plot + dialog = WaterfallPlotDialog(ax, line_data) + dialog.show() + self._waterfall_plot_dialog = dialog + def export_to_maud(self, filename): if self.mode != ViewType.polar: msg = 'Must be in polar mode. Cannot export.' diff --git a/hexrdgui/image_mode_widget.py b/hexrdgui/image_mode_widget.py index ade930ad0..62fd7554e 100644 --- a/hexrdgui/image_mode_widget.py +++ b/hexrdgui/image_mode_widget.py @@ -8,6 +8,7 @@ from hexrdgui.constants import PolarXAxisType, ViewType from hexrdgui.create_hedm_instrument import create_hedm_instrument from hexrdgui.hexrd_config import HexrdConfig +from hexrdgui.image_load_manager import ImageLoadManager from hexrdgui.ui_loader import UiLoader from hexrdgui.utils import block_signals @@ -29,6 +30,9 @@ class ImageModeWidget(QObject): # Tell the image canvas to show the snip1d polar_show_snip1d = Signal() + # Tell the image canvas to create a waterfall plot + create_waterfall_plot = Signal() + raw_show_zoom_dialog = Signal() def __init__(self, parent=None): @@ -97,6 +101,8 @@ def setup_connections(self): self.on_polar_x_axis_type_changed) self.ui.polar_active_beam.currentIndexChanged.connect( self.on_active_beam_changed) + self.ui.create_waterfall_plot.clicked.connect( + self.on_create_waterfall_plot_clicked) HexrdConfig().instrument_config_loaded.connect( self.on_instrument_config_load) @@ -141,6 +147,9 @@ def setup_connections(self): self.ui.stereo_project_from_polar.toggled.connect( HexrdConfig().set_stereo_project_from_polar) + ImageLoadManager().new_images_loaded.connect( + self.update_visibility_states) + def enable_image_mode_widget(self, b): self.ui.tab_widget.setEnabled(b) @@ -256,6 +265,16 @@ def update_visibility_states(self): self.ui.polar_active_beam.setVisible(has_multi_xrs) self.ui.polar_active_beam_label.setVisible(has_multi_xrs) + # We can only make a waterfall plot if there is more than one + # frame in the imageseries. + # If there are more than 20, that's too many, and let's just ignore + # it as well. All of the cases we know of currently should have + # no more than 15 frames in the imageseries. + can_make_waterfall_plot = ( + 1 < HexrdConfig().imageseries_length <= 20 + ) + self.ui.create_waterfall_plot.setVisible(can_make_waterfall_plot) + def auto_generate_cartesian_params(self): if HexrdConfig().loading_state: # Don't modify the parameters if a state file is being @@ -480,6 +499,9 @@ def update_beam_names(self): def on_active_beam_changed(self): HexrdConfig().active_beam_name = self.ui.polar_active_beam.currentText() + def on_create_waterfall_plot_clicked(self): + self.create_waterfall_plot.emit() + POLAR_X_AXIS_LABELS_TO_VALUES = { '2θ': PolarXAxisType.tth, diff --git a/hexrdgui/image_tab_widget.py b/hexrdgui/image_tab_widget.py index 0256b0e56..10557c093 100644 --- a/hexrdgui/image_tab_widget.py +++ b/hexrdgui/image_tab_widget.py @@ -529,6 +529,9 @@ def export_current_plot(self, filename): def polar_show_snip1d(self): self.image_canvases[0].polar_show_snip1d() + def create_waterfall_plot(self): + self.image_canvases[0].create_waterfall_plot() + def export_to_maud(self, filename): self.image_canvases[0].export_to_maud(filename) diff --git a/hexrdgui/main_window.py b/hexrdgui/main_window.py index 81212681d..019f5ea6b 100644 --- a/hexrdgui/main_window.py +++ b/hexrdgui/main_window.py @@ -90,9 +90,6 @@ class MainWindow(QObject): - # Emitted when new images are loaded - new_images_loaded = Signal() - # Emitted when a new mask is added new_mask_added = Signal(str) @@ -273,7 +270,6 @@ def setup_connections(self): self.ui.action_run_fit_grains.triggered.connect( self.on_action_run_fit_grains_triggered) self.ui.action_run_wppf.triggered.connect(self.run_wppf) - self.new_images_loaded.connect(self.images_loaded) self.ui.image_tab_widget.update_needed.connect(self.update_all) self.ui.image_tab_widget.new_mouse_position.connect( self.new_mouse_position) @@ -314,6 +310,8 @@ def setup_connections(self): self.ui.image_tab_widget.polar_show_snip1d) self.image_mode_widget.raw_show_zoom_dialog.connect( self.on_show_raw_zoom_dialog) + self.image_mode_widget.create_waterfall_plot.connect( + self.ui.image_tab_widget.create_waterfall_plot) self.ui.action_open_images.triggered.connect( self.open_image_files) @@ -338,7 +336,7 @@ def setup_connections(self): self.on_physics_package_modified) ImageLoadManager().update_needed.connect(self.update_all) - ImageLoadManager().new_images_loaded.connect(self.new_images_loaded) + ImageLoadManager().new_images_loaded.connect(self.images_loaded) ImageLoadManager().images_transformed.connect(self.update_config_gui) ImageLoadManager().live_update_status.connect(self.set_live_update) ImageLoadManager().state_updated.connect( @@ -527,7 +525,8 @@ def load_dummy_images(self): ImageFileManager().load_dummy_images() self.update_all(clear_canvases=True) self.ui.action_transform_detectors.setEnabled(False) - self.new_images_loaded.emit() + # Manually indicate that new images were loaded + ImageLoadManager().new_images_loaded.emit() def open_image_file(self): images_dir = HexrdConfig().images_dir diff --git a/hexrdgui/resources/ui/image_mode_widget.ui b/hexrdgui/resources/ui/image_mode_widget.ui index 590b4d347..1aaca5820 100644 --- a/hexrdgui/resources/ui/image_mode_widget.ui +++ b/hexrdgui/resources/ui/image_mode_widget.ui @@ -883,6 +883,16 @@ + + + + <html><head/><body><p>Create a waterfall plot using the images in the image series.</p><p>This will first generate a polar view image for every index in the image series (which can be time-consuming depending on the polar resolution settings). Each polar view images is then used to create an azimuthal lineout.</p><p>The azimuthal lineouts for each of the polar view images is then plotted together in a waterfall plot dialog. This dialog allows plots to be adjusted as needed.</p></body></html> + + + Waterfall Plot + + + @@ -1051,8 +1061,26 @@ cartesian_plane_normal_rotate_x cartesian_virtual_plane_distance cartesian_plane_normal_rotate_y + polar_active_beam + polar_pixel_size_tth + polar_pixel_size_eta + polar_res_tth_min + polar_res_tth_max + polar_res_eta_min + polar_res_eta_max + polar_apply_snip1d + polar_snip1d_width + polar_snip1d_algorithm + polar_snip1d_numiter + polar_show_snip1d + polar_apply_erosion + polar_apply_tth_distortion + polar_tth_distortion_overlay polar_azimuthal_overlays azimuthal_offset + polar_x_axis_type + create_waterfall_plot + polar_apply_scaling_to_lineout stereo_size stereo_show_border stereo_project_from_polar diff --git a/hexrdgui/waterfall_plot.py b/hexrdgui/waterfall_plot.py new file mode 100644 index 000000000..6f5b7f808 --- /dev/null +++ b/hexrdgui/waterfall_plot.py @@ -0,0 +1,270 @@ +from typing import Callable + +from PySide6.QtCore import Qt +from PySide6.QtGui import QResizeEvent +from PySide6.QtWidgets import ( + QDialog, QLabel, QSizePolicy, QVBoxLayout, QWidget +) + +from matplotlib.axes import Axes +from matplotlib.backend_bases import FigureCanvasBase, KeyEvent, MouseEvent +from matplotlib.backends.backend_qtagg import ( + NavigationToolbar2QT as NavigationToolbar +) +from matplotlib.figure import Figure +import numpy as np + +# Line data is a list of (x, y) for each line +LineData = list[tuple[np.ndarray, np.ndarray]] + + +class WaterfallPlot: + """Waterfall Plot + + This class manages button clicks and key events for interactions + with a waterfall plot + """ + def __init__(self, ax: Axes, line_data: LineData): + self.ax = ax + self.create_lines(line_data) + + self.currently_dragging = None + self._prev_mouse_coords = None + self._shift_held = False + + self._mpl_cids = [] + + self.connect() + + def create_lines(self, line_data: LineData): + # Compute a default offset + offset = np.nanmax([np.nanmax(y) for _, y in line_data]) + offset *= 1.05 + + lines = [] + for i, (x, y) in enumerate(line_data): + lines.append(self.ax.plot( + x, + y + offset * i, + lw=2.5, + label=f'Frame {i + 1}', + )[0]) + + self.lines = lines + + self.ax.legend() + + # Also cache the line data for mouse interactions + cached_line_data = [] + for line in lines: + cached_line_data.append(np.array(line.get_data()).T) + self._cached_line_data = cached_line_data + + @property + def figure(self) -> Figure: + return self.ax.figure + + @property + def canvas(self) -> FigureCanvasBase: + return self.figure.canvas + + @property + def _mpl_callbacks(self) -> dict[str, Callable]: + return { + 'button_press_event': self.on_button_press, + 'button_release_event': self.on_button_release, + 'key_press_event': self.on_key_press, + 'key_release_event': self.on_key_release, + 'motion_notify_event': self.on_motion, + 'scroll_event': self.on_scroll, + } + + def on_button_press(self, event: MouseEvent): + if event.inaxes is not self.ax: + return + + if self.ax.get_navigate_mode() is not None: + # Zooming or panning is active. Ignore this click. + return + + coords_clicked = np.array((event.xdata, event.ydata)) + + # Find the closest line, and drag that one + closest_line_idx = self._find_closest_line(coords_clicked) + + self._prev_mouse_coords = coords_clicked + self.currently_dragging = closest_line_idx + + def on_button_release(self, event: MouseEvent): + self.currently_dragging = None + self._prev_mouse_coords = None + self.canvas.draw_idle() + + def on_key_press(self, event: KeyEvent): + if event.key == 'shift': + self._shift_held = True + + def on_key_release(self, event: KeyEvent): + if event.key == 'shift': + self._shift_held = False + + def on_motion(self, event: MouseEvent): + if self.currently_dragging is None or event.inaxes is not self.ax: + return + + mouse_coords = np.array((event.xdata, event.ydata)) + adjustment = mouse_coords - self._prev_mouse_coords + if not self._shift_held: + # If shift is not held, only allow y to vary + adjustment[0] = 0 + + data = self._cached_line_data[self.currently_dragging] + line = self.lines[self.currently_dragging] + data += adjustment + + line.set_data(data.T) + + # Rescale the axes + # Maybe this is something we want the user to be able to disable? + self.ax.relim() + self.ax.autoscale_view() + + # Redraw + self.canvas.draw_idle() + + self._prev_mouse_coords = mouse_coords + + def on_scroll(self, event: MouseEvent): + mouse_coords = np.array((event.xdata, event.ydata)) + + # Find the closest line, and drag that one + closest_line_idx = self._find_closest_line(mouse_coords) + + base_scale = 1.1 + if event.button == 'up': + # Increase the data intensity + scale_factor = base_scale + else: + # Decrease the data intensity + scale_factor = 1 / base_scale + + data = self._cached_line_data[closest_line_idx] + + # Don't allow the mean to change + mean_y = np.nanmean(data[:, 1]) + + data[:, 1] = (data[:, 1] - mean_y) * scale_factor + mean_y + + line = self.lines[closest_line_idx] + line.set_data(data.T) + + # Redraw + self.canvas.draw_idle() + + def _find_closest_line(self, coords: np.ndarray) -> int: + # Find the closest line to a set of coordinates and return + # the closest line index + min_distance = np.inf + closest_line_idx = -1 + for i, data in enumerate(self._cached_line_data): + distances = np.sqrt((data - coords)**2).sum(axis=1) + min_dist = np.nanmin(distances) + if min_dist < min_distance: + min_distance = min_dist + closest_line_idx = i + + return closest_line_idx + + def connect(self): + for k, f in self._mpl_callbacks.items(): + cid = self.canvas.mpl_connect(k, f) + self._mpl_cids.append(cid) + + def disconnect(self): + for cid in self._mpl_cids: + self.canvas.mpl_disconnect(cid) + + self._mpl_cids.clear() + + +class WaterfallPlotDialog(QDialog): + def __init__(self, ax: Axes, line_data: LineData, parent: QWidget = None): + super().__init__(parent) + + self.setWindowTitle('Waterfall Plot') + + # Add minimize, maximize, and close buttons + self.setWindowFlags( + Qt.WindowMinimizeButtonHint | + Qt.WindowMaximizeButtonHint | + Qt.WindowCloseButtonHint + ) + + self.waterfall_plot = WaterfallPlot(ax, line_data) + canvas = self.waterfall_plot.canvas + + layout = QVBoxLayout() + self.setLayout(layout) + + # Add a label describing the mouse interactions + label1 = QLabel( + 'Click and drag a plot to adjust Y. ' + 'Hold shift and then click and drag a plot to adjust both X and Y.' + ) + label2 = QLabel( + 'Hover mouse over a line and use the mouse wheel to rescale ' + 'the intensities of that line' + ) + for label in (label1, label2): + label.setAlignment(Qt.AlignCenter) + label.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum) + layout.addWidget(label) + + # Add the canvas + canvas.figure.tight_layout() + canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + layout.addWidget(canvas) + + # Add a navigation toolbar too + self.toolbar = NavigationToolbar(canvas, self) + self.toolbar.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Minimum) + layout.addWidget(self.toolbar) + layout.setAlignment(self.toolbar, Qt.AlignCenter) + + def resizeEvent(self, event: QResizeEvent): + # We override this function because we want the matplotlib canvas + # to also resize whenever the dialog is resized. + super().resizeEvent(event) + self.waterfall_plot.figure.tight_layout() + + +if __name__ == '__main__': + from PySide6.QtWidgets import QApplication + + import matplotlib.pyplot as plt + + app = QApplication() + + # Test example + fig, ax = plt.subplots() + data1 = np.load('example_integration.npy') + data2 = data1.copy() + data3 = data2.copy() + + line_data = [] + for data in (data1, data2, data3): + line_data.append((*data.T,)) + + label_kwargs = { + 'fontsize': 15, + 'family': 'serif', + } + ax.set_ylabel(r'Azimuthal Average', **label_kwargs) + + polar_xlabel = r'2$\theta_{{nom}}$ [deg]' + ax.set_xlabel(polar_xlabel, **label_kwargs) + + dialog = WaterfallPlotDialog(ax, line_data) + dialog.show() + + app.exec()