|
2 | 2 | from typing import TypeVar |
3 | 3 |
|
4 | 4 | import jax |
| 5 | +import jax.numpy as jnp |
5 | 6 | import matplotlib.pyplot as plt |
| 7 | +import mmapy |
6 | 8 | import numpy as np |
7 | 9 | import pyvista as pv |
8 | 10 | from mpl_toolkits.axes_grid1 import make_axes_locatable |
@@ -159,3 +161,186 @@ def hex_to_pyvista( |
159 | 161 | mesh.cell_data[name] = data |
160 | 162 |
|
161 | 163 | return mesh |
| 164 | + |
| 165 | + |
def hex_grid(
    Lx: float, Ly: float, Lz: float, Nx: int, Ny: int, Nz: int
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Builds a structured hexahedral grid centered on the origin.

    The grid spans [-L/2, L/2] along each axis with Nx * Ny * Nz points,
    giving (Nx-1) * (Ny-1) * (Nz-1) cells.

    Returns:
        A (points, cells) pair: points has shape (Nx*Ny*Nz, 3) and cells has
        shape (num_cells, 8) with point indices into the flattened points.
    """
    axes = [
        jnp.linspace(-half, half, count)
        for half, count in ((Lx / 2, Nx), (Ly / 2, Ny), (Lz / 2, Nz))
    ]
    coords = jnp.meshgrid(*axes, indexing="ij")
    points = jnp.stack(coords, axis=-1).reshape(-1, 3)

    flat_index = jnp.arange(Nx * Ny * Nz).reshape(Nx, Ny, Nz)
    lo, hi = slice(None, -1), slice(1, None)
    # Corner ordering: bottom z-face counterclockwise, then top z-face
    # counterclockwise (the VTK_HEXAHEDRON convention).
    corner_slices = (
        (lo, lo, lo),
        (hi, lo, lo),
        (hi, hi, lo),
        (lo, hi, lo),
        (lo, lo, hi),
        (hi, lo, hi),
        (hi, hi, hi),
        (lo, hi, hi),
    )
    cells = jnp.stack([flat_index[s] for s in corner_slices], axis=-1).reshape(-1, 8)

    return points, cells
| 197 | + |
| 198 | + |
| 199 | +class MMAOptimizer: |
| 200 | + """A wrapper for the MMA optimizer from mmapy. |
| 201 | +
|
| 202 | + Source is github.com/arjendeetman/GCMMA-MMA-Python. |
| 203 | + mmapy is a pretty barebones implementation of MMA in python. It should work for now. |
| 204 | + Alternatives to consider: |
| 205 | + - github.com/LLNL/pyMMAopt |
| 206 | + - pyopt.org/reference/optimizers.mma.html. |
| 207 | +
|
| 208 | + """ |
| 209 | + |
| 210 | + def __init__( |
| 211 | + self, |
| 212 | + x_init: jax.typing.ArrayLike, |
| 213 | + x_min: jax.typing.ArrayLike, |
| 214 | + x_max: jax.typing.ArrayLike, |
| 215 | + num_constraints: jax.typing.ArrayLike, |
| 216 | + constraint_scale: jax.typing.ArrayLike = 1000.0, |
| 217 | + x_update_limit: jax.typing.ArrayLike = 0.1, |
| 218 | + ) -> None: |
| 219 | + self.n = x_init.shape[0] |
| 220 | + self.m = num_constraints |
| 221 | + self.__check_input_sizes(x_init, x_min, x_max) |
| 222 | + |
| 223 | + # follow the original MMA variable names... |
| 224 | + self.asyinit = 0.5 |
| 225 | + self.asyincr = 1.2 |
| 226 | + self.asydecr = 0.7 |
| 227 | + self.objective_scale: float = 100.0 |
| 228 | + self.objective_scale_factor: float = 1.0 |
| 229 | + |
| 230 | + self.eeen = np.ones((self.n, 1)) |
| 231 | + self.eeem = np.ones((self.m, 1)) |
| 232 | + self.zeron = np.zeros((self.n, 1)) |
| 233 | + self.zerom = np.zeros((self.m, 1)) |
| 234 | + |
| 235 | + self.xval = x_init |
| 236 | + self.xold1 = self.xval.copy() |
| 237 | + self.xold2 = self.xval.copy() |
| 238 | + self.x_min = x_min |
| 239 | + self.x_max = x_max |
| 240 | + self.low = self.x_min.copy() |
| 241 | + self.upp = self.x_max.copy() |
| 242 | + self.c = constraint_scale + self.zerom.copy() |
| 243 | + self.d = self.zerom.copy() |
| 244 | + self.a0 = 1 |
| 245 | + self.a = self.zerom.copy() |
| 246 | + self.move = x_update_limit |
| 247 | + |
| 248 | + def calculate_next_x( |
| 249 | + self, |
| 250 | + objective_value: jax.typing.ArrayLike, |
| 251 | + objective_gradient: jax.typing.ArrayLike, |
| 252 | + constraint_values: jax.typing.ArrayLike, |
| 253 | + constraint_gradients: jax.typing.ArrayLike, |
| 254 | + iteration: int, |
| 255 | + x: jax.typing.ArrayLike, |
| 256 | + x_min: jax.typing.ArrayLike = None, |
| 257 | + x_max: jax.typing.ArrayLike = None, |
| 258 | + ) -> jax.typing.ArrayLike: |
| 259 | + if iteration < 1: |
| 260 | + raise Exception("The MMA problem expects an iteration count >= 1.") |
| 261 | + |
| 262 | + # The MMA problem works best with an objective scaled around [1, 100] |
| 263 | + if iteration == 1: |
| 264 | + self.objective_scale_factor = np.abs(self.objective_scale / objective_value) |
| 265 | + objective_value *= self.objective_scale_factor |
| 266 | + objective_gradient = ( |
| 267 | + np.asarray(objective_gradient) * self.objective_scale_factor |
| 268 | + ) |
| 269 | + |
| 270 | + # the bounds dont necessarily change every iteration |
| 271 | + if x_min is None: |
| 272 | + x_min = self.x_min |
| 273 | + if x_max is None: |
| 274 | + x_max = self.x_max |
| 275 | + |
| 276 | + self.__check_input_sizes( |
| 277 | + x, |
| 278 | + x_min, |
| 279 | + x_max, |
| 280 | + objective_gradient=objective_gradient, |
| 281 | + constraint_values=constraint_values, |
| 282 | + constraint_gradients=constraint_gradients, |
| 283 | + ) |
| 284 | + |
| 285 | + # calculate the next iteration of x |
| 286 | + xmma, _ymma, _zmma, _lam, _xsi, _eta, _mu, _zet, _s, low, upp = mmapy.mmasub( |
| 287 | + self.m, |
| 288 | + self.n, |
| 289 | + iteration, |
| 290 | + x, |
| 291 | + x_min, |
| 292 | + x_max, |
| 293 | + self.xold1, |
| 294 | + self.xold2, |
| 295 | + objective_value, |
| 296 | + objective_gradient, |
| 297 | + constraint_values, |
| 298 | + constraint_gradients, |
| 299 | + self.low, |
| 300 | + self.upp, |
| 301 | + self.a0, |
| 302 | + self.a, |
| 303 | + self.c, |
| 304 | + self.d, |
| 305 | + move=self.move, |
| 306 | + asyinit=self.asyinit, |
| 307 | + asyincr=self.asyincr, |
| 308 | + asydecr=self.asydecr, |
| 309 | + ) |
| 310 | + # update internal copies for mma |
| 311 | + self.xold2 = self.xold1.copy() |
| 312 | + self.xold1 = self.xval.copy() |
| 313 | + self.xval = xmma.copy() |
| 314 | + self.low = low |
| 315 | + self.upp = upp |
| 316 | + |
| 317 | + return xmma |
| 318 | + |
| 319 | + def __check_input_sizes( |
| 320 | + self, |
| 321 | + x: jax.typing.ArrayLike, |
| 322 | + x_min: jax.typing.ArrayLike, |
| 323 | + x_max: jax.typing.ArrayLike, |
| 324 | + objective_gradient: jax.typing.ArrayLike = None, |
| 325 | + constraint_values: jax.typing.ArrayLike = None, |
| 326 | + constraint_gradients: jax.typing.ArrayLike = None, |
| 327 | + ) -> None: |
| 328 | + def check_shape(shape: tuple, expected_shape: tuple, name: str) -> None: |
| 329 | + if (len(shape) == 1) or ( |
| 330 | + shape[0] != expected_shape[0] or shape[1] != expected_shape[1] |
| 331 | + ): |
| 332 | + raise TypeError( |
| 333 | + f"MMAError: The '{name}' was expected to have shape {expected_shape} but has shape {shape}." |
| 334 | + ) |
| 335 | + |
| 336 | + check_shape(x.shape, (self.n, 1), "parameter vector") |
| 337 | + check_shape(x_min.shape, (self.n, 1), "parameter minimum bound vector") |
| 338 | + check_shape(x_max.shape, (self.n, 1), "parameter maximum bound vector") |
| 339 | + if objective_gradient is not None: |
| 340 | + check_shape(objective_gradient.shape, (self.n, 1), "objective gradient") |
| 341 | + if constraint_values is not None: |
| 342 | + check_shape(constraint_values.shape, (self.m, 1), "constraint values") |
| 343 | + if constraint_gradients is not None: |
| 344 | + check_shape( |
| 345 | + constraint_gradients.shape, (self.m, self.n), "constraint gradients" |
| 346 | + ) |
0 commit comments