|
| 1 | +from collections.abc import Sequence |
| 2 | +from typing import TypeVar |
| 3 | + |
| 4 | +import jax |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import numpy as np |
| 7 | +import pyvista as pv |
| 8 | +from mpl_toolkits.axes_grid1 import make_axes_locatable |
| 9 | + |
| 10 | + |
| 11 | +def plot_mesh( |
| 12 | + mesh: dict, bounds: Sequence[float], save_path: str | None = None |
| 13 | +) -> None: |
| 14 | + """Plot a 3D triangular mesh with boundary conditions visualization. |
| 15 | +
|
| 16 | + Args: |
| 17 | + mesh: Dictionary containing 'points' and 'faces' arrays. |
| 18 | + save_path: Optional path to save the plot as an image file. |
| 19 | + bounds: bounds of the 3D space. |
| 20 | + """ |
| 21 | + Lx = bounds[0] |
| 22 | + Ly = bounds[1] |
| 23 | + Lz = bounds[2] |
| 24 | + |
| 25 | + fig = plt.figure(figsize=(10, 8)) |
| 26 | + ax = fig.add_subplot(111, projection="3d") |
| 27 | + ax.plot_trisurf( |
| 28 | + mesh["points"][:, 0], |
| 29 | + mesh["points"][:, 1], |
| 30 | + mesh["points"][:, 2], |
| 31 | + triangles=mesh["faces"], |
| 32 | + alpha=0.7, |
| 33 | + antialiased=True, |
| 34 | + color="lightblue", |
| 35 | + edgecolor="black", |
| 36 | + ) |
| 37 | + |
| 38 | + ax.set_xlim(-Lx / 2, Lx / 2) |
| 39 | + ax.set_ylim(-Ly / 2, Ly / 2) |
| 40 | + ax.set_zlim(-Lz / 2, Lz / 2) |
| 41 | + |
| 42 | + # set equal aspect ratio |
| 43 | + ax.set_box_aspect( |
| 44 | + ( |
| 45 | + (Lx) / (Ly), |
| 46 | + 1, |
| 47 | + (Lz) / (Ly), |
| 48 | + ) |
| 49 | + ) |
| 50 | + |
| 51 | + # x axis label |
| 52 | + ax.set_xlabel("X") |
| 53 | + ax.set_ylabel("Y") |
| 54 | + ax.set_zlabel("Z") |
| 55 | + |
| 56 | + if save_path: |
| 57 | + # avoid showing the plot in notebook |
| 58 | + plt.savefig(save_path) |
| 59 | + plt.close(fig) |
| 60 | + |
| 61 | + |
| 62 | +def plot_grid_slice(field_slice, extent, ax, title, xlabel, ylabel): |
| 63 | + im = ax.imshow(field_slice.T, extent=extent, origin="lower") |
| 64 | + ax.set_title(title) |
| 65 | + ax.set_xlabel(xlabel) |
| 66 | + ax.set_ylabel(ylabel) |
| 67 | + # add colorbar |
| 68 | + divider = make_axes_locatable(ax) |
| 69 | + cax = divider.append_axes("right", size="5%", pad=0.1) |
| 70 | + plt.colorbar(im, cax=cax, orientation="vertical") |
| 71 | + return im |
| 72 | + |
| 73 | + |
| 74 | +def plot_grid(field, Lx, Ly, Lz, Nx, Ny, Nz, title="SDF"): |
| 75 | + _, axs = plt.subplots(1, 3, figsize=(15, 5)) |
| 76 | + |
| 77 | + plot_grid_slice( |
| 78 | + field[Nx // 2, :, :], |
| 79 | + extent=(-Ly / 2, Ly / 2, -Lz / 2, Lz / 2), |
| 80 | + ax=axs[0], |
| 81 | + title=f"{title} slice at x=0", |
| 82 | + xlabel="y", |
| 83 | + ylabel="z", |
| 84 | + ) |
| 85 | + plot_grid_slice( |
| 86 | + field[:, Ny // 2, :], |
| 87 | + extent=(-Lx / 2, Lx / 2, -Lz / 2, Lz / 2), |
| 88 | + ax=axs[1], |
| 89 | + title=f"{title} slice at y=0", |
| 90 | + xlabel="x", |
| 91 | + ylabel="z", |
| 92 | + ) |
| 93 | + plot_grid_slice( |
| 94 | + field[:, :, Nz // 2], |
| 95 | + extent=(-Lx / 2, Lx / 2, -Ly / 2, Ly / 2), |
| 96 | + ax=axs[2], |
| 97 | + title=f"{title} slice at z=0", |
| 98 | + xlabel="x", |
| 99 | + ylabel="y", |
| 100 | + ) |
| 101 | + |
| 102 | + |
| 103 | +T = TypeVar("T") |
| 104 | + |
| 105 | + |
| 106 | +def stop_grads_int(x: T) -> T: |
| 107 | + """Stops gradient computation. |
| 108 | +
|
| 109 | + We cannot use jax.lax.stop_gradient directly because Tesseract meshes are |
| 110 | + nested dictionaries with arrays and integers, and jax.lax.stop_gradient |
| 111 | + does not support integers. |
| 112 | +
|
| 113 | + Args: |
| 114 | + x: Input value. |
| 115 | +
|
| 116 | + Returns: |
| 117 | + Value with stopped gradients. |
| 118 | + """ |
| 119 | + |
| 120 | + def stop(x): |
| 121 | + return jax._src.ad_util.stop_gradient_p.bind(x) |
| 122 | + |
| 123 | + return jax.tree_util.tree_map(stop, x) |
| 124 | + |
| 125 | + |
| 126 | +def hex_to_pyvista( |
| 127 | + pts: jax.typing.ArrayLike, faces: jax.typing.ArrayLike, cell_data: dict |
| 128 | +) -> pv.UnstructuredGrid: |
| 129 | + """Convert hex mesh defined by points and faces into a PyVista UnstructuredGrid. |
| 130 | +
|
| 131 | + Args: |
| 132 | + pts: Array of point coordinates, shape (N, 3). |
| 133 | + faces: Array of hexahedral cell connectivity, shape (M, 8). |
| 134 | + cell_data: additional cell center data. |
| 135 | +
|
| 136 | + Returns: |
| 137 | + PyVista mesh representing the hexahedral grid. |
| 138 | + """ |
| 139 | + pts = np.array(pts) |
| 140 | + faces = np.array(faces) |
| 141 | + |
| 142 | + # Define the cell type for hexahedrons (VTK_HEXAHEDRON = 12) |
| 143 | + cell_type = pv.CellType.HEXAHEDRON |
| 144 | + cell_types = np.array([cell_type] * faces.shape[0], dtype=np.uint8) |
| 145 | + |
| 146 | + # Prepare the cells array: [number_of_points, i0, i1, i2, i3, i4, i5, i6, i7] |
| 147 | + n_cells = faces.shape[0] |
| 148 | + cells = np.empty((n_cells, 9), dtype=np.int64) |
| 149 | + cells[:, 0] = 8 # Each cell has 8 points |
| 150 | + cells[:, 1:9] = faces |
| 151 | + |
| 152 | + # Flatten the cells array for PyVista |
| 153 | + cells = cells.flatten() |
| 154 | + |
| 155 | + mesh = pv.UnstructuredGrid(cells, cell_types, pts) |
| 156 | + |
| 157 | + # Add cell data |
| 158 | + for name, data in cell_data.items(): |
| 159 | + mesh.cell_data[name] = data |
| 160 | + |
| 161 | + return mesh |
0 commit comments