This guide covers development setup, code organization, and contribution guidelines.
If you're new to VAJAX, start with these resources:
- Getting Started - Install and run your first simulation
- Architecture Overview - How the pieces fit together
- Supported Devices - What device models are available
VAJAX is a SPICE-class circuit simulator built on JAX. If you're unfamiliar with circuit simulation, here are the core concepts:
- Modified Nodal Analysis (MNA): Formulates the circuit equations as `G*V = I`, where `G` is the conductance matrix, `V` holds the node voltages (MNA also adds branch currents for voltage sources), and `I` holds the source currents. Each device "stamps" its contributions into `G` and `I`. See MNA on Wikipedia.
- Newton-Raphson iteration: Nonlinear circuits require iterative solving. At each step, we linearize the circuit around the current solution and solve `J * delta = -f(V)`, where `J` is the Jacobian and `f` is the residual (KCL violations).
- Transient analysis: Time-domain simulation using numerical integration (trapezoidal rule or Gear's method). Capacitors and inductors introduce time derivatives, which the integration method discretizes into equivalent conductances and current sources (companion models) at each timestep.
- Verilog-A / OpenVAF: Device models (transistors, diodes, etc.) are written in Verilog-A, compiled by OpenVAF to machine code, then translated to JAX functions. This gives us production-quality device models with automatic differentiation.
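To make the companion-model idea concrete, here is a minimal sketch that is not VAJAX code: it integrates a hypothetical circuit (a 1 V step source, a 1 kΩ resistor, and a 1 µF capacitor to ground) with the trapezoidal rule. At each step the capacitor is replaced by a conductance `2C/h` in parallel with a current source built from the previous step's voltage and current. The name `rc_step_response` and its parameters are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def rc_step_response(vs=1.0, r=1e3, c=1e-6, h=1e-5, n_steps=500):
    """Capacitor voltage of a hypothetical source-R-C circuit, trapezoidal rule."""
    g = 1.0 / r
    g_eq = 2.0 * c / h  # companion conductance of the capacitor

    def step(carry, _):
        v_prev, i_prev = carry
        # Companion current source built from the previous step's state
        i_eq = g_eq * v_prev + i_prev
        # KCL at the capacitor node: g * (v - vs) + (g_eq * v - i_eq) = 0
        v_new = (g * vs + i_eq) / (g + g_eq)
        # Capacitor current implied by the trapezoidal rule
        i_new = g_eq * (v_new - v_prev) - i_prev
        return (v_new, i_new), v_new

    # At t = 0 the capacitor is uncharged, so the full step drops across R
    v0 = jnp.array(0.0, dtype=jnp.float32)
    i0 = jnp.array(vs * g, dtype=jnp.float32)
    _, v = jax.lax.scan(step, (v0, i0), None, length=n_steps)
    return v  # v[k] is the capacitor voltage at t = (k + 1) * h
```

With `h = 10 µs` and `RC = 1 ms`, `v[99]` corresponds to `t = RC` and should land close to the analytic value `1 - e^-1 ≈ 0.632`.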
- Python 3.11-3.13
- uv package manager
- Rust toolchain (for building openvaf-py)
- LLVM 18 (for OpenVAF, macOS only)
```bash
# Clone and enter directory
git clone <repo-url>
cd vajax

# Install dependencies with uv (--extra test includes pytest and test utilities)
uv sync --extra test

# Run tests to verify setup
JAX_PLATFORMS=cpu uv run pytest tests/ -v
```

OpenVAF requires LLVM 18 on macOS:

```bash
# Install LLVM via Homebrew
brew install llvm@18

# Build OpenVAF
./scripts/build_openvaf.sh

# Build openvaf-py
cd openvaf-py
LLVM_SYS_181_PREFIX=/opt/homebrew/opt/llvm@18 uv run maturin develop
```

For GPU support:

```bash
# Install with CUDA support
uv sync --extra cuda12

# Test GPU availability
JAX_PLATFORMS=cuda uv run python -c "import jax; print(jax.devices())"
```

The repository is organized as follows:

```
vajax/
├── vajax/              # Main library
│   ├── devices/        # Device models
│   ├── analysis/       # Circuit solvers
│   ├── netlist/        # Netlist parsing
│   └── benchmarks/     # Benchmark infrastructure
├── openvaf-py/         # OpenVAF Python bindings (Rust)
├── tests/              # Test suite
├── scripts/            # Build and profiling scripts
├── docs/               # Architecture documentation
└── vendor/             # External dependencies (VACASK)
```
| Module | Purpose | Key Files |
|---|---|---|
| `vajax.devices` | Device models | `verilog_a.py`, `vsource.py` |
| `vajax.analysis` | Solvers | `engine.py`, `mna.py`, `dc_operating_point.py`, `solver.py`, `transient/` |
| `vajax.netlist` | Parsing | `parser.py`, `circuit.py` |
| `vajax.benchmarks` | Benchmarks | `registry.py`, `runner.py` |
```bash
# All tests (force CPU for reproducibility)
JAX_PLATFORMS=cpu uv run pytest tests/ -v

# Specific test file
JAX_PLATFORMS=cpu uv run pytest tests/test_vacask_suite.py -v

# Single test
JAX_PLATFORMS=cpu uv run pytest tests/test_resistor.py::test_ohms_law -v

# With coverage
JAX_PLATFORMS=cpu uv run pytest tests/ --cov=vajax --cov-report=html

# openvaf-py tests (separate environment)
cd openvaf-py && JAX_PLATFORMS=cpu ../.venv/bin/python -m pytest tests/ -v
```

Test locations:
- `tests/test_*.py` - Main library tests
- `tests/test_vacask_*.py` - VACASK benchmark integration tests
- `openvaf-py/tests/` - OpenVAF translator tests
```bash
# Format with ruff
uv run ruff format vajax tests

# Lint
uv run ruff check vajax tests

# Type check
uv run pyright vajax
```

Configuration is in `pyproject.toml`:
- Line length: 100 characters
- Target: Python 3.11
- Lints: E and W (pycodestyle), F (Pyflakes), I (isort), N (pep8-naming)
- Pure functions preferred: Device models should be pure JAX functions without side effects
- Type hints: Use type hints for public APIs
- Docstrings: Document public functions with parameter descriptions
- No trailing whitespace: Configure editor to trim trailing whitespace
- Snake case: Use `snake_case` for functions and variables
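As an illustration of these conventions (a pure function, type hints, documented parameters, `snake_case` names), here is a hand-written Shockley diode current function. It is a hypothetical example only: real VAJAX device models come from Verilog-A via OpenVAF, so you would not normally write device physics in Python.

```python
import jax.numpy as jnp
from jax import Array


def diode_current(v: Array, i_s: float = 1e-14, n: float = 1.0, v_t: float = 0.02585) -> Array:
    """Shockley diode current (hypothetical example, not a VAJAX device).

    Args:
        v: Voltage across the diode, in volts.
        i_s: Saturation current, in amperes.
        n: Emission coefficient.
        v_t: Thermal voltage, in volts (about 25.85 mV at 300 K).

    Returns:
        Diode current in amperes, with the same shape as ``v``.
    """
    # No side effects: the output depends only on the arguments
    return i_s * (jnp.exp(v / (n * v_t)) - 1.0)
```

Because the function is pure, it can be passed unchanged to `jax.jit`, `jax.grad`, or `jax.vmap`.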
All devices except voltage/current sources are compiled from Verilog-A via
OpenVAF and wrapped by VerilogADevice (vajax/devices/verilog_a.py).
Device instances are grouped by type and evaluated in parallel using jax.vmap
for GPU efficiency. The engine handles MNA stamping automatically.
The simulator uses MNA to form the circuit equations:

```
G*V + C*dV/dt = I
```

Where:
- G: Conductance matrix (from device stamps)
- V: Node voltages (unknowns)
- C: Capacitance matrix (for transient)
- I: Current vector (from sources)
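The stamping idea can be sketched in a few lines of JAX. This is a minimal illustration, not the engine's implementation: `stamp_resistor`, `stamp_current_source`, and the `GROUND = -1` convention are hypothetical names chosen for the example. It drives the circuit with a current source so that the unknowns are node voltages only (a voltage source would add a branch-current row to the system).

```python
import jax.numpy as jnp

GROUND = -1  # hypothetical convention: ground is the reference node, not an unknown


def stamp_resistor(g_mat, a, b, r):
    """Add the conductance stamp of a resistor between nodes a and b."""
    g = 1.0 / r
    if a != GROUND:
        g_mat = g_mat.at[a, a].add(g)
    if b != GROUND:
        g_mat = g_mat.at[b, b].add(g)
    if a != GROUND and b != GROUND:
        g_mat = g_mat.at[a, b].add(-g).at[b, a].add(-g)
    return g_mat


def stamp_current_source(rhs, a, b, i):
    """Add a source that drives current i out of node a and into node b."""
    if a != GROUND:
        rhs = rhs.at[a].add(-i)
    if b != GROUND:
        rhs = rhs.at[b].add(i)
    return rhs


# Hypothetical circuit: 1 mA into node 0, R1 = 1 kOhm from node 0 to node 1,
# R2 = 1 kOhm from node 1 to ground.
g_mat = jnp.zeros((2, 2))
rhs = jnp.zeros(2)
g_mat = stamp_resistor(g_mat, 0, 1, 1e3)
g_mat = stamp_resistor(g_mat, 1, GROUND, 1e3)
rhs = stamp_current_source(rhs, GROUND, 0, 1e-3)
v = jnp.linalg.solve(g_mat, rhs)  # node voltages, approximately [2.0, 1.0]
```

The result matches hand analysis: 1 mA flows through 2 kΩ in total, so node 0 sits at 2 V and the midpoint at 1 V.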
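DC and transient solvers use Newton-Raphson:

```
V_new = V - J^(-1) * f(V)
```

Where:
- f(V): Residual (KCL violations)
- J: Jacobian (df/dV, computed via JAX autodiff)

In practice, Newton solvers compute this update by solving the linear system `J * delta = -f(V)` and setting `V_new = V + delta`, rather than forming `J^(-1)` explicitly.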
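A minimal sketch of this loop for a single-node circuit: a 5 V source drives a diode through 1 kΩ, the Jacobian comes from `jax.jacfwd`, and each update solves `J * delta = -f(V)`. The circuit values, the `newton_solve` helper, and the 0.1 V step clamp are assumptions for illustration. The clamp is a crude stand-in for the junction-voltage limiting that production simulators use; without it, the first undamped step from 0 V would jump to about 5 V and overflow `exp()`.

```python
import jax
import jax.numpy as jnp

VS, R = 5.0, 1e3         # hypothetical source voltage (V) and series resistance (ohm)
IS, VT = 1e-14, 0.02585  # hypothetical diode saturation current (A), thermal voltage (V)


def residual(v):
    """KCL at the diode node: current out through R plus current into the diode."""
    return (v - VS) / R + IS * (jnp.exp(v / VT) - 1.0)


jacobian = jax.jacfwd(residual)  # J = df/dV via forward-mode autodiff


def newton_solve(v0, max_iter=50, tol=1e-6, max_step=0.1):
    v = v0
    for _ in range(max_iter):
        f = residual(v)
        j = jacobian(v)
        delta = jnp.linalg.solve(j, -f)
        # Crude step limiting so exp() cannot overflow on early iterations
        delta = jnp.clip(delta, -max_step, max_step)
        v = v + delta
        if float(jnp.max(jnp.abs(delta))) < tol:
            break
    return v


v = newton_solve(jnp.zeros(1))  # diode voltage, roughly 0.69 V for these values
```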
For performance, devices are grouped by type and evaluated in parallel:

```python
# Instead of evaluating each device in a Python loop:
for device in mosfets:
    I = device.evaluate(V)

# ...we evaluate all instances of a device type in one batched call:
I_all = jax.vmap(mosfet_evaluate)(V_batched, params_batched)
```

All devices (resistors, capacitors, diodes, MOSFETs, etc.) are routed through OpenVAF Verilog-A compilation via `VerilogADevice` in `vajax/devices/verilog_a.py`. The only exceptions are voltage and current sources, which are handled separately in `vajax/devices/vsource.py`.
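A runnable version of the batching pattern, assuming a hypothetical per-instance function `diode_eval(v, params)` in place of a compiled Verilog-A model: `jax.vmap` maps it over the leading axis of the voltages and of every array in the parameter tuple, and the result matches a per-instance Python loop.

```python
import jax
import jax.numpy as jnp


def diode_eval(v, params):
    """Hypothetical per-instance device function: current through one diode."""
    i_s, n = params
    return i_s * (jnp.exp(v / (n * 0.02585)) - 1.0)


# Terminal voltages and parameters for three instances of the same device type
v_batched = jnp.array([0.5, 0.6, 0.7])
params_batched = (jnp.array([1e-14, 2e-14, 1e-14]), jnp.array([1.0, 1.0, 1.2]))

# Loop version: one Python call per instance
i_loop = jnp.stack([
    diode_eval(v_batched[k], (params_batched[0][k], params_batched[1][k]))
    for k in range(3)
])

# Batched version: a single vectorized call over all instances
i_all = jax.vmap(diode_eval)(v_batched, params_batched)
```

The loop dispatches one small computation per instance, while the `vmap` version produces one vectorized computation for the whole group, which is what keeps large device counts efficient on a GPU.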
To add a new device:
- Write or obtain a Verilog-A model (`.va` file)
- Compile it with OpenVAF to produce an `.osdi` module
- Reference it in a VACASK `.sim` netlist with `load "your_model.osdi"`
- The engine will automatically wrap it via `VerilogADevice`, with batched `jax.vmap` evaluation for GPU efficiency
There is no need to write Python device code: OpenVAF handles the compilation from Verilog-A to JAX-compatible functions.
To add tests, create tests/test_your_device.py using the public
CircuitEngine API with a .sim netlist that exercises the device.
To add a new analysis:
- Create `vajax/analysis/your_analysis.py`
- Use `DeviceInfo` and `DeviceType` from `vajax/analysis/mna.py` for device management
- For transient-style analyses, subclass the `TransientStrategy` ABC from `vajax/analysis/transient/base.py`
- Build on existing patterns from `dc_operating_point.py` or the `transient/` package
- Export it from `vajax/analysis/__init__.py`
```python
import jax

# Print intermediate values (works inside jit-compiled code, unlike print())
jax.debug.print("Value: {x}", x=my_array)

# Raise an error as soon as any operation produces a NaN
jax.config.update("jax_debug_nans", True)

# Disable JIT for debugging (ordinary print() and breakpoints work again)
with jax.disable_jit():
    result = my_function(inputs)
```

Common issues:

- JIT compilation errors: Usually caused by Python conditionals on traced values
  - Use `jax.lax.cond` instead of `if`
  - Use `jax.lax.select` (or `jnp.where`) for element-wise conditionals
- NaN in results: Check for:
  - Division by zero (add a small epsilon)
  - Log of non-positive values
  - Invalid device parameters
- Convergence failures: Try:
  - Homotopy chain: `run_homotopy_chain()` from `vajax.analysis.homotopy`
  - GMIN stepping: `gmin_stepping()` with `mode="gdev"` or `"gshunt"`
  - Source stepping: `source_stepping()`
  - Increased iteration limit
  - Relaxed tolerances
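The sketch below shows both conditional styles inside `jax.jit`, plus the common "double `where`" pattern for avoiding NaNs. Because `jnp.where` evaluates both branches, a `log` of a non-positive value in the unselected branch still poisons the gradient with NaN unless its input is made safe first. The functions `limited_log` and `scalar_branch` are illustrative, not part of VAJAX.

```python
import jax
import jax.numpy as jnp


@jax.jit
def limited_log(x):
    # A Python `if x > 0:` would fail here, because x is a tracer inside jit.
    # jnp.where selects element-wise, but BOTH branches are computed, so the
    # unselected branch must not produce inf/NaN either.
    safe_x = jnp.where(x > 0, x, 1.0)  # feed log a harmless value where x <= 0
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)


@jax.jit
def scalar_branch(x):
    # jax.lax.cond runs exactly one of two functions based on a scalar predicate.
    return jax.lax.cond(x > 0, lambda y: jnp.sqrt(y), lambda y: jnp.zeros_like(y), x)
```

Without the `safe_x` step, `jax.grad` of `limited_log` would return NaN at every non-positive input even though the forward values look correct.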
```bash
# CPU profiling
JAX_PLATFORMS=cpu uv run python scripts/profile_gpu.py --benchmark ring

# GPU profiling with Perfetto traces
uv run python scripts/profile_gpu_cloudrun.py --benchmark ring --timesteps 50
```

To submit a change:
- Create a feature branch from `main`
- Make your changes with clear commit messages
- Ensure tests pass: `JAX_PLATFORMS=cpu uv run pytest tests/ -v`
- Run linter: `uv run ruff check vajax tests`
- Push and create PR
- Wait for CI checks to pass
Commit messages follow this format:

```
<type>: <short description>

<optional longer description>
```
Types:
- feat: New feature
- fix: Bug fix
- refactor: Code restructuring
- test: Adding tests
- docs: Documentation
- perf: Performance improvement
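If you want to check a message locally before pushing, a minimal sketch of a format check follows. It is not part of the repository; it assumes the header must be exactly `<type>: <description>` with one of the types above, and that an optional longer description is separated from the header by a blank line.

```python
import re

# Allowed types, taken from the list above
COMMIT_TYPES = ("feat", "fix", "refactor", "test", "docs", "perf")
HEADER_RE = re.compile(rf"^({'|'.join(COMMIT_TYPES)}): \S.*$")


def check_commit_message(message: str) -> bool:
    """Return True if the message matches '<type>: <short description>'."""
    lines = message.splitlines()
    if not lines or not HEADER_RE.match(lines[0]):
        return False
    # Assumption: a longer description must be separated by a blank line.
    return len(lines) == 1 or lines[1] == ""
```

For example, `check_commit_message("fix: handle empty netlist")` accepts the header, while a message starting with `Fix stuff` is rejected.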
- JAX Documentation
- OpenVAF GitHub
- SPICE Theory
- `docs/` folder for architecture details