# ONNX Runtime GenAI - AI Coding Agent Instructions

## Architecture Overview

This is **ONNX Runtime GenAI**, a high-performance inference library for generative AI models. The codebase implements the complete generative AI loop including preprocessing, ONNX Runtime inference, logits processing, search/sampling, and KV cache management.

### Core Components

- **`src/models/`** - Model implementations with support for LLMs, VLMs (Vision), ALMs (Audio), and Pipeline models
- **`src/engine/`** - Request batching engine for concurrent model execution with dynamic scheduling
- **`src/generators.h`** - Central generator logic coordinating the full inference pipeline
- **`src/ort_genai.h`** - Zero-cost C++ wrapper around the C API for automatic resource management
- **Language bindings**: Python (`src/python/`), C# (`src/csharp/`), Java (`src/java/`), Objective-C (`src/objectivec/`)

### Key Abstractions

```cpp
// Core inference flow: Model -> Tokenizer -> Generator
auto model = OgaModel::Create("phi-2");
auto tokenizer = OgaTokenizer::Create(*model);
auto params = OgaGeneratorParams::Create(*model);
auto generator = OgaGenerator::Create(*model, *params);
```

The `State` class hierarchy in `src/models/model.h` handles device-specific execution, while the `Engine` class in `src/engine/` manages request batching and scheduling.

## Build System & Development Workflow

### Primary Build Commands

```bash
# Cross-platform Python build script (preferred)
python build.py --config Release --use_cuda --build_java --enable_tests

# Platform-specific scripts
build.bat   # Windows batch
build.sh    # Linux/Mac shell
```

### Key Build Options (cmake/options.cmake)

- `USE_CUDA/USE_DML/USE_ROCM` - Hardware acceleration backends
- `USE_WINML` - Windows ML integration requiring `WINML_SDK_VERSION` parameter
- `ENABLE_JAVA/ENABLE_PYTHON` - Language binding compilation
- `USE_GUIDANCE` - Constrained generation support

### WinML Build Pattern

WinML builds require explicit SDK version specification:

```bash
# WinML build - WINML_SDK_VERSION is mandatory
python build.py --use_winml -DWINML_SDK_VERSION=1.8.2084
```

WinML integration downloads `Microsoft.WindowsAppSDK.ML` via NuGet and copies headers/libs to a local `ort/` directory.

### Testing

```bash
# Python tests with test models
python -m pytest -sv test_onnxruntime_genai_api.py -k "test_name" --test_models ..\test_models

# C++ unit tests via CMake/CTest
ctest --build-config Release --output-on-failure
```

## Code Patterns & Conventions

### Device Interface Pattern

Each hardware backend implements `DeviceInterface` (defined in `src/smartptrs.h`):

```cpp
// Illustrative signatures; see src/smartptrs.h for the real interface
struct CudaInterface : DeviceInterface {
  std::unique_ptr<DeviceBuffer> Allocate(size_t size) override;
  void CopyToDevice(DeviceSpan<uint8_t> dst, std::span<const uint8_t> src) override;
};
```

### Model State Management

Models follow the `State` pattern where each model type extends the base `State` class:

```cpp
struct State {
  virtual DeviceSpan<float> Run(int total_length,
                                DeviceSpan<int32_t>& next_tokens) = 0;
  virtual void RewindTo(size_t index) {}  // For session continuation
};
```

### Error Handling Convention

Use `OgaCheckResult()` wrapper for C API error propagation:

```cpp
OgaCheckResult(OgaCreateModel(model_path, &model)); // Throws std::runtime_error
```
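
The idiom is simple enough to sketch with stand-ins. `FakeResult` and `FakeCreateModel` below are hypothetical placeholders, not the real C API; only the check-and-throw shape mirrors `OgaCheckResult`.

```cpp
#include <stdexcept>
#include <string>

using FakeResult = int;  // 0 == success, nonzero == error code (stand-in)

// Convert a failed C-style status into a C++ exception.
inline void CheckResult(FakeResult result) {
  if (result != 0)
    throw std::runtime_error("API call failed with code " + std::to_string(result));
}

// Hypothetical status-returning C API for illustration only.
inline FakeResult FakeCreateModel(const char* path) {
  return path == nullptr ? 1 : 0;  // fail on a null path
}
```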

### Memory Management

- **DeviceSpan/DeviceBuffer**: Device-agnostic memory abstractions
- **std::unique_ptr with custom deleters**: For C API resource cleanup
- **LeakChecked<T>**: Debug-mode leak detection for core types
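
The custom-deleter idiom pairs each C API Create function with its Destroy function so handles are released automatically. The names below (`FakeModel`, `FakeDestroyModel`) are stand-ins for a real Create/Destroy pair:

```cpp
#include <memory>

struct FakeModel { int id; };  // stand-in for an opaque C API handle

inline int g_destroy_calls = 0;  // lets us observe cleanup in the test below
inline void FakeDestroyModel(FakeModel* m) {
  ++g_destroy_calls;
  delete m;
}

// The deleter calls the C API's destroy function instead of operator delete.
struct ModelDeleter {
  void operator()(FakeModel* m) const { FakeDestroyModel(m); }
};
using ModelPtr = std::unique_ptr<FakeModel, ModelDeleter>;

inline ModelPtr MakeModel() { return ModelPtr(new FakeModel{42}); }
```

When a `ModelPtr` goes out of scope, the destroy function runs exactly once, which is the behavior `LeakChecked<T>` verifies in debug builds.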

## Critical Integration Points

### ONNX Runtime Dependency Management

Azure DevOps (ADO) pipelines obtain the ONNX Runtime libraries and headers in one of three ways:

1. **Explicit `ORT_HOME`** - The pipeline provides pre-built ORT artifacts (preferred)
2. **Auto-download via CMake** - `cmake/ortlib.cmake` fetches from the ORT-Nightly feed when `ORT_HOME` is unset
3. **Python build driver** - `tools/python/util/dependency_resolver.py` downloads NuGet packages

### Model Loading Pipeline

1. **Config parsing** (`src/config.cpp`) - Reads `genai_config.json` model metadata
2. **ONNX session creation** via `onnxruntime_api.h` wrappers
3. **Device interface selection** based on provider availability
4. **KV cache initialization** (`src/models/kv_cache.cpp`) for transformer models
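
The ordering of the four stages can be sketched with stubs; each function below is a placeholder, and the real implementations live in the files listed above:

```cpp
#include <string>
#include <vector>

struct LoadTrace { std::vector<std::string> stages; };

inline void ParseConfig(LoadTrace& t)   { t.stages.push_back("config"); }
inline void CreateSession(LoadTrace& t) { t.stages.push_back("session"); }
inline void SelectDevice(LoadTrace& t)  { t.stages.push_back("device"); }
inline void InitKvCache(LoadTrace& t)   { t.stages.push_back("kv_cache"); }

// Stages run strictly in this order: later stages depend on earlier ones
// (the session needs the config, the cache needs the device).
inline LoadTrace LoadModelPipeline() {
  LoadTrace t;
  ParseConfig(t);    // 1. read genai_config.json metadata
  CreateSession(t);  // 2. create the ONNX Runtime session
  SelectDevice(t);   // 3. pick an execution provider / device interface
  InitKvCache(t);    // 4. allocate the KV cache for transformer models
  return t;
}
```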

### Multi-Modal Support

Vision models (Phi-Vision) use separate processor classes:

- `PhiImageProcessor` - Image tokenization and preprocessing
- `MultiModalProcessor` - Coordinates text/image inputs
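
Conceptually, the coordination step splices image placeholder tokens into the text token stream. The sketch below is a toy illustration (the tokenizer, token values, and `-1` placeholder are made up), not the real processor logic:

```cpp
#include <string>
#include <vector>

// Toy character-level tokenizer standing in for the real text tokenizer.
inline std::vector<int> TokenizeText(const std::string& text) {
  std::vector<int> ids;
  for (char c : text) ids.push_back(static_cast<int>(c));
  return ids;
}

// Prepend one placeholder token per image feature slot, then the text tokens.
inline std::vector<int> ProcessPrompt(const std::string& text, int num_image_tokens) {
  std::vector<int> ids(num_image_tokens, -1);  // -1 == image placeholder (made up)
  auto text_ids = TokenizeText(text);
  ids.insert(ids.end(), text_ids.begin(), text_ids.end());
  return ids;
}
```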

### Execution Provider Detection

Hardware acceleration auto-detection follows this priority:

1. CUDA (if `USE_CUDA=ON` and CUDA runtime available)
2. DirectML (Windows, if `USE_DML=ON`)
3. CPU fallback
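
The priority order above reduces to a simple first-match fallback. In this sketch the booleans stand for "compiled in and usable at runtime"; the real checks query build flags and the provider libraries:

```cpp
#include <string>

// First available provider wins, in fixed priority order; CPU always works.
inline std::string SelectProvider(bool cuda_available, bool dml_available) {
  if (cuda_available) return "cuda";  // highest priority
  if (dml_available) return "dml";    // Windows DirectML
  return "cpu";                       // always-available fallback
}
```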

## Project-Specific Gotchas

### Windows-Specific Build Requirements

- **Visual Studio 2022** required for C++20 features
- **WinML integration** requires specific NuGet package versions (see `cmake/nuget.cmake`)
- **Cross-compilation** for ARM64/ARM64EC supported via CMake platform flags

### Model Compatibility Matrix

The repo supports specific model architectures - check `src/models/model_type.h` for the canonical list. New models require:

1. Config template in model directory
2. State implementation extending the base `State` class
3. Optional custom processors for multi-modal inputs
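
Step 2 can be sketched self-contained. `FakeDeviceSpan` below is a stand-in (a plain `std::vector`) for the real `DeviceSpan`, and the base class is simplified from `src/models/model.h`:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

template <typename T>
using FakeDeviceSpan = std::vector<T>;  // stand-in for the real DeviceSpan

struct State {
  virtual ~State() = default;
  virtual FakeDeviceSpan<float> Run(int total_length,
                                    FakeDeviceSpan<int32_t>& next_tokens) = 0;
  virtual void RewindTo(size_t /*index*/) {}  // optional: session continuation
};

// A new model overrides Run to produce logits for the next decoding step.
struct MyModelState : State {
  FakeDeviceSpan<float> Run(int /*total_length*/,
                            FakeDeviceSpan<int32_t>& next_tokens) override {
    // Toy "logits": one zero score per input token.
    return FakeDeviceSpan<float>(next_tokens.size(), 0.0f);
  }
};
```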

### Performance Considerations

- **KV caching** is automatically managed but can be configured via `runtime_settings.cpp`
- **Continuous decoding** (session continuation) requires careful state management
- **Multi-LoRA** adapters use separate weight loading in `src/models/adapters.cpp`

## Testing Strategy

Tests are organized by language binding:

- **C++ tests**: `test/` directory, focused on core API validation
- **Python tests**: `test/python/`, includes end-to-end model testing
- **Platform tests**: Android/iOS tests run via emulator/simulator

Always test with actual model files from the `test/test_models/` directory rather than mock data.