
⚡Flash-ANSR: Fast Amortized Neural Symbolic Regression


Papers

  • WIP

Usage

```sh
pip install flash-ansr
```

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import flash_ansr
from flash_ansr import (
    FlashANSR,
    SoftmaxSamplingConfig,
    install_model,
    get_path,
)

# Select a model from Hugging Face
# https://huggingface.co/models?search=flash-ansr-v23.0
MODEL = "psaegert/flash-ansr-v23.0-120M"

# Download the latest snapshot of the model
# By default, the model is downloaded to the directory `./models/` in the package root
install_model(MODEL)

# Load the model
model = FlashANSR.load(
    directory=get_path('models', MODEL),
    generation_config=SoftmaxSamplingConfig(choices=32),  # or BeamSearchConfig / MCTSGenerationConfig
    n_restarts=8,
).to(device)

# Define data
X = ...
y = ...

# Fit the model to the data
model.fit(X, y, verbose=True)

# Show the best expression
print(model.get_expression())

# Predict with the best expression
y_pred = model.predict(X)
```
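
For a quick end-to-end check, `X` and `y` can be plain NumPy arrays of shape `(n_points, n_variables)` and `(n_points,)`. The target function, shapes, and noise level below are illustrative assumptions, not values taken from the documentation:

```python
import numpy as np

# Hypothetical sanity check: recover a simple expression from noisy samples
rng = np.random.default_rng(0)
X = rng.uniform(-3.0, 3.0, size=(256, 2))
y = np.sin(X[:, 0]) + 0.5 * X[:, 1] ** 2 + rng.normal(0.0, 0.01, size=256)

model.fit(X, y, verbose=True)
print(model.get_expression())

# Mean squared error of the best expression on the training points
y_pred = model.predict(X)
print(np.mean((y_pred - y) ** 2))
```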

Explore more in the Demo Notebook.

Overview

Training

⚡ANSR Training on Fully Procedurally Generated Data Inspired by NeSymReS (Biggio et al. 2021)
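In this NeSymReS-style setup, training examples are generated procedurally rather than drawn from a fixed benchmark: a random expression skeleton is sampled from a grammar, constants and input points are drawn at random, and the expression is evaluated to produce an (expression, X, y) triple. The sketch below illustrates the general recipe; the grammar, value ranges, and sampling probabilities are illustrative assumptions, not Flash-ANSR's actual data pipeline.

```python
# Sketch of NeSymReS-style procedural training-data generation (illustrative only)
import numpy as np
import sympy as sp

UNARY = [sp.sin, sp.cos, sp.exp]
BINARY = [sp.Add, sp.Mul]

def sample_expression(variables, depth, rng):
    """Recursively sample a random expression tree over the given variables."""
    if depth == 0 or rng.random() < 0.3:
        # Leaf node: a variable or a random constant
        if rng.random() < 0.7:
            return variables[rng.integers(len(variables))]
        return sp.Float(rng.normal())
    if rng.random() < 0.5:
        return UNARY[rng.integers(len(UNARY))](sample_expression(variables, depth - 1, rng))
    op = BINARY[rng.integers(len(BINARY))]
    return op(sample_expression(variables, depth - 1, rng),
              sample_expression(variables, depth - 1, rng))

def sample_training_example(n_points=128, n_variables=2, seed=None):
    """Sample one (expression, X, y) triple for training the amortized model."""
    rng = np.random.default_rng(seed)
    variables = list(sp.symbols(f'x0:{n_variables}'))
    expr = sample_expression(variables, depth=3, rng=rng)
    f = sp.lambdify(variables, expr, 'numpy')
    X = rng.uniform(-5.0, 5.0, size=(n_points, n_variables))
    y = np.broadcast_to(np.asarray(f(*X.T), dtype=float), (n_points,))
    return expr, X, y

# In practice, examples with NaN/inf values (e.g. from exp overflow) would be filtered out
expr, X, y = sample_training_example(seed=0)
print(expr)
```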

Architecture

FlashANSR Architecture. The model consists of an upgraded version of the Set Transformer encoder (Lee et al. 2019) and a pre-norm Transformer decoder (Vaswani et al. 2017; Xiong et al. 2020) that acts as a generative model over symbolic expressions.
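
In code, this corresponds to a permutation-invariant encoder over the (x, y) support points (no positional encoding) whose outputs condition an autoregressive pre-norm Transformer decoder over expression tokens via cross-attention. The PyTorch sketch below mirrors this shape; a plain self-attention stack stands in for the Set Transformer encoder, and the dimensions, vocabulary size, and layer counts are illustrative assumptions, not Flash-ANSR's exact modules.

```python
# Illustrative PyTorch sketch of the encoder-decoder layout described above
import torch
import torch.nn as nn

class SetEncoder(nn.Module):
    """Permutation-invariant encoder over (x, y) points (no positional encoding)."""
    def __init__(self, n_features, d_model=256, n_heads=8, n_layers=4):
        super().__init__()
        self.embed = nn.Linear(n_features + 1, d_model)  # each point is (x_1, ..., x_d, y)
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model,
            batch_first=True, norm_first=True,  # pre-norm
        )
        self.encoder = nn.TransformerEncoder(layer, n_layers)

    def forward(self, points):  # points: (batch, n_points, n_features + 1)
        return self.encoder(self.embed(points))  # (batch, n_points, d_model)

class ExpressionDecoder(nn.Module):
    """Autoregressive pre-norm Transformer decoder over expression tokens."""
    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerDecoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model,
            batch_first=True, norm_first=True,  # pre-norm
        )
        self.decoder = nn.TransformerDecoder(layer, n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens, memory):  # tokens: (batch, seq), memory: encoder output
        seq = tokens.size(1)
        causal_mask = torch.triu(
            torch.full((seq, seq), float('-inf'), device=tokens.device), diagonal=1
        )
        h = self.decoder(self.token_embed(tokens), memory, tgt_mask=causal_mask)
        return self.lm_head(h)  # next-token logits over the expression vocabulary

# Encode 200 support points with 3 input variables, decode a 5-token prefix
encoder, decoder = SetEncoder(n_features=3), ExpressionDecoder(vocab_size=100)
memory = encoder(torch.randn(1, 200, 4))
logits = decoder(torch.randint(0, 100, (1, 5)), memory)  # (1, 5, 100)
```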

Results

Coming soon

Citation

```bibtex
@mastersthesis{flash-ansr2024-thesis,
  author    = {Paul Saegert},
  title     = {Flash Amortized Neural Symbolic Regression},
  school    = {Heidelberg University},
  year      = {2025},
  url       = {https://github.com/psaegert/flash-ansr-thesis}
}

@software{flash-ansr2024,
  author    = {Paul Saegert},
  title     = {Flash Amortized Neural Symbolic Regression},
  year      = {2024},
  publisher = {GitHub},
  version   = {0.4.5},
  url       = {https://github.com/psaegert/flash-ansr}
}
```