A differentiable finite element analysis (FEA) solver for structural optimization, enabled by JAX.
Developed by Gaoyuan Wu @ Princeton.
- Automatic differentiation (AD): an easy and accurate way to evaluate gradients. AD avoids deriving derivatives manually and the truncation errors of numerical differentiation. It is handy for sensitivity analysis in gradient-based optimization and for training neural networks (NN) with differentiable physics (see the sketch after this list).
- Accelerated linear algebra (XLA) and just-in-time (JIT) compilation: these JAX features speed up gradient evaluation.
- Hardware acceleration: runs on GPUs and TPUs for faster performance.
- Supports beam-column elements and MITC-4 quadrilateral shell elements.
- Supports shape optimization, size optimization, and topology optimization.
- Integrates seamlessly with machine learning (ML) libraries.
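As an illustration of how AD and JIT work together, here is a minimal JAX sketch, independent of JaxSSO's own API; the quadratic objective is purely illustrative:

```python
import jax
import jax.numpy as jnp

# A toy objective: strain-energy-like quadratic form 0.5 * u^T K u.
# In JaxSSO, the objective would instead come from the FEA solution.
def objective(u, K):
    return 0.5 * u @ (K @ u)

# jax.grad differentiates the objective w.r.t. its first argument;
# jax.jit compiles the gradient function with XLA for speed.
grad_fn = jax.jit(jax.grad(objective, argnums=0))

K = jnp.eye(3) * 2.0      # stand-in "stiffness" matrix
u = jnp.array([1.0, 2.0, 3.0])
print(grad_fn(u, K))      # exact gradient K @ u, no finite differences
```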
An overview of the package structure of JaxSSO is shown in the following figure.
The `element.py` module implements the underlying mechanics and formulations of the structural elements, such as beam-columns and MITC-4 shells.
The `model.py` module creates the finite element model to be analyzed. Users use this module to add structural elements, specify boundary conditions, and impose loads.
The `assemblemodel.py` module assembles the linear system of equations.
The `solver.py` module conducts the forward analysis and solves for the solution.
The `SSO_model.py` module is for backward propagation/optimization. Users specify design parameters and an objective function; derivatives are then obtained automatically thanks to AD. A sketch of this workflow is shown below.
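A minimal end-to-end sketch of that workflow might look like the following. Note that the import paths, class names, and method signatures (`Model`, `add_node`, `add_beamcol`, `add_support`, `add_nodal_load`, `SSO_model`, and their arguments) are assumptions for illustration, not the actual JaxSSO API; consult the modules above and the Colab examples for the real interface.

```python
from JaxSSO import model, SSO_model  # module names from the figure above;
                                     # import paths are an assumption

# --- Build the FE model (model.py): nodes, elements, BCs, loads ---
fem = model.Model()                               # hypothetical constructor
fem.add_node(0, 0.0, 0.0, 0.0)                    # hypothetical: tag, x, y, z
fem.add_node(1, 1.0, 0.0, 0.0)
fem.add_beamcol(0, 0, 1, E=2.1e11, G=8.1e10,
                Iy=1e-5, Iz=1e-5, J=2e-5, A=0.01) # hypothetical signature
fem.add_support(0, [1, 1, 1, 1, 1, 1])            # hypothetical: fix 6 DOFs
fem.add_nodal_load(1, dof=2, value=-1000.0)       # hypothetical point load

# --- Wrap for optimization (SSO_model.py): parameters + objective ---
sso = SSO_model.SSO_model(fem)                    # hypothetical wrapper
# Derivatives of the objective w.r.t. the design parameters then come
# from AD rather than manual derivation or finite differences.
```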
Install it with pip: `pip install JaxSSO`
JaxSSO is written in Python and requires:
- numpy >= 1.22.0.
- JAX: "JAX is Autograd and XLA, brought together for high-performance machine learning research." Please refer to this link for the installation of JAX.
- scipy.
Optional:
- Nlopt: a library for nonlinear optimization. It has a Python interface, which is used here. Refer to this link for the installation of Nlopt; alternatively, you can use `pip install nlopt` (see nlopt-python).
- Flax: a neural network library based on JAX. JaxSSO can be integrated with Flax; see Examples/Neural_Network_Topo_Shape.ipynb.
- Optax: an optimization library based on JAX that can be used to train neural networks (a minimal sketch follows this list).
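For instance, Optax's gradient transformations plug directly into a JAX gradient loop. Below is a minimal sketch assuming a differentiable scalar loss; the `loss` function is a stand-in, not part of JaxSSO:

```python
import jax
import jax.numpy as jnp
import optax

# Stand-in differentiable loss; in practice this could wrap a
# JaxSSO forward analysis (e.g., structural compliance).
def loss(params):
    return jnp.sum((params - 1.0) ** 2)

params = jnp.zeros(5)
optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(params)

for _ in range(100):
    grads = jax.grad(loss)(params)                    # AD gradient
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)     # gradient step

print(params)  # approaches jnp.ones(5)
```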
The project provides interactive examples on Google Colab for a quick start; no local installation is required.
- Shape optimization of a grid shell: geometry from Favilli et al. 2024
Please star this project, share it with others, and/or cite us if you find our work interesting and helpful.
We have a new manuscript.
Our previous work can be found in this paper; cite it using:
@article{wu2023framework,
  title={A framework for structural shape optimization based on automatic differentiation, the adjoint method and accelerated linear algebra},
  author={Wu, Gaoyuan},
  journal={Structural and Multidisciplinary Optimization},
  volume={66},
  pages={151},
  year={2023},
  publisher={Springer},
  url={https://doi.org/10.1007/s00158-023-03601-0},
  doi={10.1007/s00158-023-03601-0}
}





