11from typing import Any
22
3+ import jax
34import jax .numpy as jnp
45
56# import numpy as jnp
67from jax .scipy .interpolate import RegularGridInterpolator
78from pydantic import BaseModel , Field
8- from scipy .interpolate import griddata
99from tesseract_core .runtime import Array , Differentiable , Float32 , Int32 , ShapeDType
10+ from tesseract_core .runtime .tree_transforms import filter_func , flatten_with_paths
1011
1112#
1213# Schemata
@@ -400,59 +401,135 @@ def generate_mesh(
400401 return pts , cells
401402
402403
403- def apply (inputs : InputSchema ) -> OutputSchema :
404- """Generate hexahedral mesh and interpolate field values onto cell centers.
404+ def compute_integral_volume (grid ):
405+ """Computes the integral volume (3D cumulative sum) of the grid.
406+
407+ Args:
408+ grid: grid values
409+ """
410+ # We pad with one layer of zeros on the 'left' of every dimension.
411+ # This handles the boundary condition where a hex starts at index 0.
412+ # Cumulative sum along Depth, Height, and Width
413+ integral = jnp .cumsum (grid , axis = - 1 )
414+ integral = jnp .cumsum (integral , axis = - 2 )
415+ integral = jnp .cumsum (integral , axis = - 3 )
416+
417+ # Pad with zeros at the beginning of each spatial dimension
418+ padding = [(0 , 0 )] * (grid .ndim - 3 ) + [(1 , 0 ), (1 , 0 ), (1 , 0 )]
419+ integral_padded = jnp .pad (integral , padding , mode = "constant" , constant_values = 0 )
420+
421+ return integral_padded
422+
423+
424+ def apply_fn (inputs : dict ) -> dict :
425+ """Compute the compliance of the structure given a density field.
405426
406427 Args:
407- inputs: InputSchema, inputs to the function .
428+ inputs: Dictionary containing input parameters and density field .
408429
409430 Returns:
410- OutputSchema, outputs of the function .
431+ Dictionary containing the compliance of the structure .
411432 """
412- Lx = inputs .domain_size [0 ]
413- Ly = inputs .domain_size [1 ]
414- Lz = inputs .domain_size [2 ]
433+ Lx = inputs ["domain_size" ][0 ]
434+ Ly = inputs ["domain_size" ][1 ]
435+ Lz = inputs ["domain_size" ][2 ]
436+
437+ field_values = inputs ["field_values" ]
438+ max_points = inputs ["max_points" ]
439+ max_cells = inputs ["max_cells" ]
440+ sizing_field = inputs ["sizing_field" ]
441+ max_levels = inputs ["max_subdivision_levels" ]
442+
443+ # no stop grads
415444 pts , cells = generate_mesh (
416445 Lx = Lx ,
417446 Ly = Ly ,
418447 Lz = Lz ,
419- sizing_field = inputs . sizing_field ,
420- max_levels = inputs . max_subdivision_levels ,
448+ sizing_field = sizing_field ,
449+ max_levels = max_levels ,
421450 )
422451
423- pts_padded = jnp .zeros ((inputs .max_points , 3 ), dtype = pts .dtype )
452+ print ("Done building mesh" )
453+
454+ pts_padded = jnp .zeros ((max_points , 3 ), dtype = pts .dtype )
424455 pts_padded = pts_padded .at [: pts .shape [0 ], :].set (pts )
425- cells_padded = jnp .zeros ((inputs . max_cells , 8 ), dtype = cells .dtype )
456+ cells_padded = jnp .zeros ((max_cells , 8 ), dtype = cells .dtype )
426457 cells_padded = cells_padded .at [: cells .shape [0 ], :].set (cells )
427458
428- xs = jnp .linspace (- Lx / 2 , Lx / 2 , inputs .field_values .shape [0 ])
429- ys = jnp .linspace (- Ly / 2 , Ly / 2 , inputs .field_values .shape [1 ])
430- zs = jnp .linspace (- Lz / 2 , Lz / 2 , inputs .field_values .shape [2 ])
459+ def discretize (coord ):
460+ coord = coord + jnp .array ([Lx / 2 , Ly / 2 , Lz / 2 ])
461+ coord = coord / jnp .array ([Lx , Ly , Lz ])
462+ coord = coord * jnp .array ([field_values .shape ])
463+ return jnp .floor (coord ).astype (jnp .int32 )
431464
432- interpolator = RegularGridInterpolator (
433- (xs , ys , zs ),
434- inputs .field_values ,
435- method = "linear" ,
436- bounds_error = False ,
437- fill_value = - 1 ,
465+ coords_disc = jax .vmap (discretize , in_axes = 0 )(pts )[:, 0 ]
466+
467+ integral = compute_integral_volume (field_values )
468+
469+ ind = coords_disc [cells [:, 0 ]]
470+ cell_000 = integral [ind [0 ], ind [1 ], ind [2 ]]
471+
472+ ind = coords_disc [cells [:, 1 ]]
473+ cell_100 = integral [ind [0 ], ind [1 ], ind [2 ]]
474+
475+ ind = coords_disc [cells [:, 2 ]]
476+ cell_110 = integral [ind [0 ], ind [1 ], ind [2 ]]
477+
478+ ind = coords_disc [cells [:, 3 ]]
479+ cell_010 = integral [ind [0 ], ind [1 ], ind [2 ]]
480+
481+ ind = coords_disc [cells [:, 4 ]]
482+ cell_001 = integral [ind [0 ], ind [1 ], ind [2 ]]
483+
484+ ind = coords_disc [cells [:, 5 ]]
485+ cell_101 = integral [ind [0 ], ind [1 ], ind [2 ]]
486+
487+ ind = coords_disc [cells [:, 6 ]]
488+ cell_111 = integral [ind [0 ], ind [1 ], ind [2 ]]
489+
490+ ind = coords_disc [cells [:, 7 ]]
491+ cell_011 = integral [ind [0 ], ind [1 ], ind [2 ]]
492+
493+ total_sum = (
494+ cell_111
495+ - cell_011
496+ - cell_101
497+ - cell_110
498+ + cell_001
499+ + cell_010
500+ + cell_100
501+ - cell_000
438502 )
439503
440- cell_centers = jnp .mean (pts [cells ], axis = 1 )
504+ volume = jnp .prod (
505+ jnp .abs (coords_disc [cells [:, 6 ]] - coords_disc [cells [:, 0 ]]), axis = - 1
506+ )
507+ volume = jnp .maximum (volume , 1.0 )
441508
442- cell_values = interpolator ( cell_centers )
509+ cell_values = total_sum / volume
443510
444- cell_values_padded = jnp .zeros ((inputs . max_cells ,), dtype = cell_values . dtype )
511+ cell_values_padded = jnp .zeros ((max_cells ,), dtype = jnp . float32 )
445512 cell_values_padded = cell_values_padded .at [: cell_values .shape [0 ]].set (cell_values )
446513
447- return OutputSchema (
448- mesh = HexMesh (
449- points = pts_padded .astype (jnp .float32 ),
450- faces = cells_padded .astype (jnp .int32 ),
451- n_points = pts .shape [0 ],
452- n_faces = cells .shape [0 ],
453- ),
454- mesh_cell_values = cell_values_padded ,
455- )
514+ return {
515+ "mesh" : {
516+ "points" : pts_padded .astype (jnp .float32 ),
517+ "faces" : cells_padded .astype (jnp .int32 ),
518+ "n_points" : pts .shape [0 ],
519+ "n_faces" : cells .shape [0 ],
520+ },
521+ "mesh_cell_values" : cell_values_padded .astype (jnp .float32 ),
522+ }
523+
524+
525+ #
526+ # Tesseract endpoints
527+ #
528+
529+
530+ def apply (inputs : InputSchema ) -> OutputSchema :
531+ """Compute the compliance of the structure given a density field."""
532+ return apply_fn (inputs .model_dump ())
456533
457534
458535def vector_jacobian_product (
@@ -461,55 +538,28 @@ def vector_jacobian_product(
461538 vjp_outputs : set [str ],
462539 cotangent_vector : dict [str , Any ],
463540) -> dict [str , Any ]:
464- """Compute vector-Jacobian product for the apply function.
465-
466- Our cotangent gradient is defined on the cells centers
467- we need to backpropagate it to the field values defined on the regular grid
468- this can be done using interpolation
469- We need to have the mesh cell center positions here, so instead of recomputing the mesh,
470- lets use the cached mesh from the last forward pass
471- print(generate_mesh.cache_info())
541+ """Compute vector-Jacobian product for specified inputs and outputs.
472542
473543 Args:
474- inputs: InputSchema, inputs to the apply function .
475- vjp_inputs: set of input variable names for which to compute the VJP .
476- vjp_outputs: set of output variable names for which the cotangent vector is provided .
477- cotangent_vector: dict mapping output variable names to their cotangent vectors .
544+ inputs: InputSchema instance containing input parameters and density field .
545+ vjp_inputs: Set of input variable names for which to compute gradients .
546+ vjp_outputs: Set of output variable names with respect to which to compute gradients .
547+ cotangent_vector: Dictionary containing cotangent vectors for the specified outputs .
478548
479549 Returns:
480- dict mapping input variable names to their VJP results .
550+ Dictionary containing the vector-Jacobian product for the specified inputs .
481551 """
482552 assert vjp_inputs == {"field_values" }
483553 assert vjp_outputs == {"mesh_cell_values" }
484554
485- Lx = inputs .domain_size [0 ]
486- Ly = inputs .domain_size [1 ]
487- Lz = inputs .domain_size [2 ]
555+ inputs = inputs .model_dump ()
488556
489- pts , cells = generate_mesh (
490- Lx = Lx ,
491- Ly = Ly ,
492- Lz = Lz ,
493- sizing_field = inputs .sizing_field ,
494- max_levels = inputs .max_subdivision_levels ,
557+ filtered_apply = filter_func (apply_fn , inputs , vjp_outputs )
558+ _ , vjp_func = jax .vjp (
559+ filtered_apply , flatten_with_paths (inputs , include_paths = vjp_inputs )
495560 )
496-
497- cell_centers = jnp .mean (pts [cells ], axis = 1 )
498-
499- xs = jnp .linspace (- Lx / 2 , Lx / 2 , inputs .field_values .shape [0 ])
500- ys = jnp .linspace (- Ly / 2 , Ly / 2 , inputs .field_values .shape [1 ])
501- zs = jnp .linspace (- Lz / 2 , Lz / 2 , inputs .field_values .shape [2 ])
502- xs , ys , zs = jnp .meshgrid (xs , ys , zs , indexing = "ij" )
503-
504- field_cotangent_vector = griddata (
505- cell_centers ,
506- cotangent_vector ["mesh_cell_values" ][: cells .shape [0 ]],
507- (xs , ys , zs ),
508- method = "nearest" ,
509- # fill_value=0.0,
510- )
511-
512- return {"field_values" : jnp .array (field_cotangent_vector ).astype (jnp .float32 )}
561+ out = vjp_func (cotangent_vector )[0 ]
562+ return out
513563
514564
515565def abstract_eval (abstract_inputs : InputSchema ) -> dict [str , ShapeDType ]:
0 commit comments