Skip to content

Commit 53d9565

Browse files
committed
propper volume integrals
1 parent b1f35f7 commit 53d9565

14 files changed

+439
-3318
lines changed

examples/ansys/bars_mesh.vtk

0 Bytes
Binary file not shown.

examples/ansys/meshing_tess/tesseract_api.py

Lines changed: 122 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from typing import Any
22

3+
import jax
34
import jax.numpy as jnp
45

56
# import numpy as jnp
67
from jax.scipy.interpolate import RegularGridInterpolator
78
from pydantic import BaseModel, Field
8-
from scipy.interpolate import griddata
99
from tesseract_core.runtime import Array, Differentiable, Float32, Int32, ShapeDType
10+
from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths
1011

1112
#
1213
# Schemata
@@ -400,59 +401,135 @@ def generate_mesh(
400401
return pts, cells
401402

402403

403-
def apply(inputs: InputSchema) -> OutputSchema:
404-
"""Generate hexahedral mesh and interpolate field values onto cell centers.
404+
def compute_integral_volume(grid):
405+
"""Computes the integral volume (3D cumulative sum) of the grid.
406+
407+
Args:
408+
grid: grid values
409+
"""
410+
# We pad with one layer of zeros on the 'left' of every dimension.
411+
# This handles the boundary condition where a hex starts at index 0.
412+
# Cumulative sum along Depth, Height, and Width
413+
integral = jnp.cumsum(grid, axis=-1)
414+
integral = jnp.cumsum(integral, axis=-2)
415+
integral = jnp.cumsum(integral, axis=-3)
416+
417+
# Pad with zeros at the beginning of each spatial dimension
418+
padding = [(0, 0)] * (grid.ndim - 3) + [(1, 0), (1, 0), (1, 0)]
419+
integral_padded = jnp.pad(integral, padding, mode="constant", constant_values=0)
420+
421+
return integral_padded
422+
423+
424+
def apply_fn(inputs: dict) -> dict:
425+
"""Compute the compliance of the structure given a density field.
405426
406427
Args:
407-
inputs: InputSchema, inputs to the function.
428+
inputs: Dictionary containing input parameters and density field.
408429
409430
Returns:
410-
OutputSchema, outputs of the function.
431+
Dictionary containing the compliance of the structure.
411432
"""
412-
Lx = inputs.domain_size[0]
413-
Ly = inputs.domain_size[1]
414-
Lz = inputs.domain_size[2]
433+
Lx = inputs["domain_size"][0]
434+
Ly = inputs["domain_size"][1]
435+
Lz = inputs["domain_size"][2]
436+
437+
field_values = inputs["field_values"]
438+
max_points = inputs["max_points"]
439+
max_cells = inputs["max_cells"]
440+
sizing_field = inputs["sizing_field"]
441+
max_levels = inputs["max_subdivision_levels"]
442+
443+
# no stop grads
415444
pts, cells = generate_mesh(
416445
Lx=Lx,
417446
Ly=Ly,
418447
Lz=Lz,
419-
sizing_field=inputs.sizing_field,
420-
max_levels=inputs.max_subdivision_levels,
448+
sizing_field=sizing_field,
449+
max_levels=max_levels,
421450
)
422451

423-
pts_padded = jnp.zeros((inputs.max_points, 3), dtype=pts.dtype)
452+
print("Done building mesh")
453+
454+
pts_padded = jnp.zeros((max_points, 3), dtype=pts.dtype)
424455
pts_padded = pts_padded.at[: pts.shape[0], :].set(pts)
425-
cells_padded = jnp.zeros((inputs.max_cells, 8), dtype=cells.dtype)
456+
cells_padded = jnp.zeros((max_cells, 8), dtype=cells.dtype)
426457
cells_padded = cells_padded.at[: cells.shape[0], :].set(cells)
427458

428-
xs = jnp.linspace(-Lx / 2, Lx / 2, inputs.field_values.shape[0])
429-
ys = jnp.linspace(-Ly / 2, Ly / 2, inputs.field_values.shape[1])
430-
zs = jnp.linspace(-Lz / 2, Lz / 2, inputs.field_values.shape[2])
459+
def discretize(coord):
460+
coord = coord + jnp.array([Lx / 2, Ly / 2, Lz / 2])
461+
coord = coord / jnp.array([Lx, Ly, Lz])
462+
coord = coord * jnp.array([field_values.shape])
463+
return jnp.floor(coord).astype(jnp.int32)
431464

432-
interpolator = RegularGridInterpolator(
433-
(xs, ys, zs),
434-
inputs.field_values,
435-
method="linear",
436-
bounds_error=False,
437-
fill_value=-1,
465+
coords_disc = jax.vmap(discretize, in_axes=0)(pts)[:, 0]
466+
467+
integral = compute_integral_volume(field_values)
468+
469+
ind = coords_disc[cells[:, 0]]
470+
cell_000 = integral[ind[0], ind[1], ind[2]]
471+
472+
ind = coords_disc[cells[:, 1]]
473+
cell_100 = integral[ind[0], ind[1], ind[2]]
474+
475+
ind = coords_disc[cells[:, 2]]
476+
cell_110 = integral[ind[0], ind[1], ind[2]]
477+
478+
ind = coords_disc[cells[:, 3]]
479+
cell_010 = integral[ind[0], ind[1], ind[2]]
480+
481+
ind = coords_disc[cells[:, 4]]
482+
cell_001 = integral[ind[0], ind[1], ind[2]]
483+
484+
ind = coords_disc[cells[:, 5]]
485+
cell_101 = integral[ind[0], ind[1], ind[2]]
486+
487+
ind = coords_disc[cells[:, 6]]
488+
cell_111 = integral[ind[0], ind[1], ind[2]]
489+
490+
ind = coords_disc[cells[:, 7]]
491+
cell_011 = integral[ind[0], ind[1], ind[2]]
492+
493+
total_sum = (
494+
cell_111
495+
- cell_011
496+
- cell_101
497+
- cell_110
498+
+ cell_001
499+
+ cell_010
500+
+ cell_100
501+
- cell_000
438502
)
439503

440-
cell_centers = jnp.mean(pts[cells], axis=1)
504+
volume = jnp.prod(
505+
jnp.abs(coords_disc[cells[:, 6]] - coords_disc[cells[:, 0]]), axis=-1
506+
)
507+
volume = jnp.maximum(volume, 1.0)
441508

442-
cell_values = interpolator(cell_centers)
509+
cell_values = total_sum / volume
443510

444-
cell_values_padded = jnp.zeros((inputs.max_cells,), dtype=cell_values.dtype)
511+
cell_values_padded = jnp.zeros((max_cells,), dtype=jnp.float32)
445512
cell_values_padded = cell_values_padded.at[: cell_values.shape[0]].set(cell_values)
446513

447-
return OutputSchema(
448-
mesh=HexMesh(
449-
points=pts_padded.astype(jnp.float32),
450-
faces=cells_padded.astype(jnp.int32),
451-
n_points=pts.shape[0],
452-
n_faces=cells.shape[0],
453-
),
454-
mesh_cell_values=cell_values_padded,
455-
)
514+
return {
515+
"mesh": {
516+
"points": pts_padded.astype(jnp.float32),
517+
"faces": cells_padded.astype(jnp.int32),
518+
"n_points": pts.shape[0],
519+
"n_faces": cells.shape[0],
520+
},
521+
"mesh_cell_values": cell_values_padded.astype(jnp.float32),
522+
}
523+
524+
525+
#
526+
# Tesseract endpoints
527+
#
528+
529+
530+
def apply(inputs: InputSchema) -> OutputSchema:
531+
"""Compute the compliance of the structure given a density field."""
532+
return apply_fn(inputs.model_dump())
456533

457534

458535
def vector_jacobian_product(
@@ -461,55 +538,28 @@ def vector_jacobian_product(
461538
vjp_outputs: set[str],
462539
cotangent_vector: dict[str, Any],
463540
) -> dict[str, Any]:
464-
"""Compute vector-Jacobian product for the apply function.
465-
466-
Our cotangent gradient is defined on the cells centers
467-
we need to backpropagate it to the field values defined on the regular grid
468-
this can be done using interpolation
469-
We need to have the mesh cell center positions here, so instead of recomputing the mesh,
470-
lets use the cached mesh from the last forward pass
471-
print(generate_mesh.cache_info())
541+
"""Compute vector-Jacobian product for specified inputs and outputs.
472542
473543
Args:
474-
inputs: InputSchema, inputs to the apply function.
475-
vjp_inputs: set of input variable names for which to compute the VJP.
476-
vjp_outputs: set of output variable names for which the cotangent vector is provided.
477-
cotangent_vector: dict mapping output variable names to their cotangent vectors.
544+
inputs: InputSchema instance containing input parameters and density field.
545+
vjp_inputs: Set of input variable names for which to compute gradients.
546+
vjp_outputs: Set of output variable names with respect to which to compute gradients.
547+
cotangent_vector: Dictionary containing cotangent vectors for the specified outputs.
478548
479549
Returns:
480-
dict mapping input variable names to their VJP results.
550+
Dictionary containing the vector-Jacobian product for the specified inputs.
481551
"""
482552
assert vjp_inputs == {"field_values"}
483553
assert vjp_outputs == {"mesh_cell_values"}
484554

485-
Lx = inputs.domain_size[0]
486-
Ly = inputs.domain_size[1]
487-
Lz = inputs.domain_size[2]
555+
inputs = inputs.model_dump()
488556

489-
pts, cells = generate_mesh(
490-
Lx=Lx,
491-
Ly=Ly,
492-
Lz=Lz,
493-
sizing_field=inputs.sizing_field,
494-
max_levels=inputs.max_subdivision_levels,
557+
filtered_apply = filter_func(apply_fn, inputs, vjp_outputs)
558+
_, vjp_func = jax.vjp(
559+
filtered_apply, flatten_with_paths(inputs, include_paths=vjp_inputs)
495560
)
496-
497-
cell_centers = jnp.mean(pts[cells], axis=1)
498-
499-
xs = jnp.linspace(-Lx / 2, Lx / 2, inputs.field_values.shape[0])
500-
ys = jnp.linspace(-Ly / 2, Ly / 2, inputs.field_values.shape[1])
501-
zs = jnp.linspace(-Lz / 2, Lz / 2, inputs.field_values.shape[2])
502-
xs, ys, zs = jnp.meshgrid(xs, ys, zs, indexing="ij")
503-
504-
field_cotangent_vector = griddata(
505-
cell_centers,
506-
cotangent_vector["mesh_cell_values"][: cells.shape[0]],
507-
(xs, ys, zs),
508-
method="nearest",
509-
# fill_value=0.0,
510-
)
511-
512-
return {"field_values": jnp.array(field_cotangent_vector).astype(jnp.float32)}
561+
out = vjp_func(cotangent_vector)[0]
562+
return out
513563

514564

515565
def abstract_eval(abstract_inputs: InputSchema) -> dict[str, ShapeDType]:

examples/ansys/optim_bars.ipynb

Lines changed: 317 additions & 3246 deletions
Large diffs are not rendered by default.

examples/ansys/rho_optim_sum_2.gif

92.7 KB
Loading
-92.1 KB
Binary file not shown.
-88 KB
Binary file not shown.
-64.7 KB
Binary file not shown.
-82.1 KB
Binary file not shown.
-89.4 KB
Binary file not shown.
-86.9 KB
Binary file not shown.

0 commit comments

Comments
 (0)