
Add continuous batching support for concurrent LLM inference#150

Open
ronaldmannak wants to merge 101 commits into ml-explore:main from PicoMLX:batching3

Conversation

@ronaldmannak
Contributor

Proposed changes

  • Continuous batching engine that transparently serves multiple concurrent requests in a shared decode loop, with zero overhead for single requests
  • BatchKVCache & BatchRotatingKVCache implementations with left-padding strategy for batched attention across sequences of different lengths
  • InferenceScheduler actor that starts in single-request mode (TokenIterator) and automatically upgrades to BatchTokenIterator when a second request arrives; 3rd+ requests join the existing batch on the fly
  • LRUPromptCache with trie-based prefix matching and LRU eviction for reusing KV state across requests with shared prefixes
  • Batch-aware RoPE via the new applyRotaryPosition helper — all 45 modified MLXLLM model files with RoPE-based attention paths now use batch-aware position handling that works in both single and batched modes
  • Automatic fallback for incompatible requests (VLMs, hybrid/SSM models using MambaCache or CacheList, and quantized KV-cache requests) — routed to the single-request path with no caller changes needed
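To make the left-padding strategy above concrete, here is a minimal sketch of the alignment idea (plain Swift for illustration; it mirrors the description, not the PR's actual padding code):

```swift
// Illustrative left-padding: shorter prompts are padded on the left so
// every sequence in the batch ends at the same position. This keeps the
// causal mask and per-sequence RoPE offsets aligned during decode.
let prompts: [[Int]] = [[11, 12, 13, 14, 15], [21, 22]]
let maxLen = prompts.map(\.count).max() ?? 0
let leftPadding = prompts.map { maxLen - $0.count }   // [0, 3]
let padded = prompts.map { Array(repeating: 0, count: maxLen - $0.count) + $0 }
// padded = [[11, 12, 13, 14, 15], [0, 0, 0, 21, 22]]
```

The per-sequence `leftPadding` values are what `BatchKVCache` carries so that padded positions can be masked out and RoPE positions shifted per sequence.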

Design

This PR adds batching support for text-generation LLMs. VLM requests are not batched in this PR, and embedding models are out of scope for this batching path.

Most LLM model types are supported. Batched generation is currently limited to models whose cache stack uses KVCacheSimple and/or RotatingKVCache, and is unavailable for models that use MambaCache or CacheList. Requests using those models automatically fall back to the single inference path. Support for MambaCache and CacheList will be added in a separate PR.

QuantizedKVCache is also not supported for batching. Requests with kvBits != nil, or requests that already carry quantized KV cache state, automatically fall back to the single inference path.

The batching system is opt-in via ModelContainer.scheduler:

```swift
let container = ModelContainer(context: context)
container.scheduler = InferenceScheduler()
// Existing generate() calls now batch transparently when the request is compatible
```
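For example, two concurrent requests against the same container end up sharing the batched decode loop (a hedged sketch: the exact `generate` call shape is illustrative, following the existing MLXLMCommon-style API rather than quoting it):

```swift
// Illustrative sketch: with `scheduler` set, the second in-flight request
// upgrades the scheduler from single-request mode to batched mode; callers
// do not change anything.
let container = ModelContainer(context: context)
container.scheduler = InferenceScheduler()

async let first = container.generate(prompt: "Summarize this document.")
async let second = container.generate(prompt: "Translate this to French.")
let (a, b) = await (first, second)
```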

Incompatible models that use MambaCache and/or CacheList are:

  • falcon_h1
  • baichuan_m1
  • lfm2
  • lfm2_moe
  • granitemoehybrid
  • qwen3_next
  • qwen3_5, qwen3_5_text, qwen3_5_moe
  • nemotron_h
  • jamba_3b
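The fallback rule boils down to a cache-type and quantization check. The helper name `isBatchCompatible` appears in this PR's commits; the body below is an illustrative reconstruction under that assumption, not the actual implementation:

```swift
// Illustrative reconstruction of the fallback rule: a request is batch-
// compatible only if no KV quantization is requested and every cache
// layer is a KVCacheSimple or RotatingKVCache (i.e. no MambaCache,
// CacheList, or QuantizedKVCache anywhere in the stack).
func isBatchCompatible(caches: [KVCache], kvBits: Int?) -> Bool {
    guard kvBits == nil else { return false }
    return caches.allSatisfy { cache in
        cache is KVCacheSimple || cache is RotatingKVCache
    }
}
```

Requests that fail this check are routed to the single-request path, so callers never see an error for an unsupported model.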

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

ronaldmannak and others added 24 commits March 27, 2026 15:01
Set up .factory/ directory with worker skills, services manifest,
init script, and library knowledge files for the batching mission.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…hing

Add Libraries/MLXLMCommon/Batching/BatchKVCache.swift porting Python mlx-lm's
BatchKVCache. Includes: init with leftPadding, update with step-based buffer
allocation, filter(batchIndices:) with left-shift optimization, extend(other:)
with right-justification, extract(idx:) returning KVCacheSimple with padding
stripped, merge([KVCache]) class method, fromSingle/toSingle conversion,
state serialization, and empty batch handling.

Add comprehensive XCTest suite in Tests/MLXLMTests/BatchKVCacheTests.swift
with 22 test cases covering all validation contract assertions (VAL-CACHE-001
through VAL-CACHE-021).

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…taryPosition helper

- Add leftPadding parameter to createCausalMask() for per-sequence padding masks (backward compatible)
- Implement makeMask() on BatchKVCache that always masks padding (including n=1 decode steps)
- Create BatchPositionedKVCache protocol with batchOffset for per-sequence RoPE offsets
- Implement applyRotaryPosition() dispatching to ArrayOffsetLayer for batch, OffsetLayer for single
- Add isBatchCompatible() detection for CacheList, MambaCache, and QuantizedKVCache
- Make BatchKVCache conform to BatchPositionedKVCache
- Add 18 unit tests covering all validation assertions

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Port BatchRotatingKVCache from Python mlx-lm for models using sliding-window
attention. Supports init with maxSize/leftPadding, multi-token concat path,
single-token in-place rotation, temporal ordering, filter/extend/extract,
merge from RotatingKVCache instances (with maxSize mismatch rejection),
makeMask with window and left-padding, and fromSingle/toSingle conversions.
Conforms to BatchPositionedKVCache protocol. Extract returns RotatingKVCache.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
The MLX Metal shader library (.metallib) is not bundled in SPM debug
builds, causing tests that trigger GPU evaluation to crash the test
runner. This adds an MLXMetalGuard helper that probes Metal availability
using withError/eval, and XCTSkipUnless/.enabled(if:) guards to all
MLX-dependent tests across the test suite.

Changes:
- New MLXMetalGuard.swift with cached Metal availability detection
- skipIfMetalUnavailable() helper for XCTest-based tests
- BatchKVCacheTests: all 22 tests guarded, fixed always-true 'is' check
- BatchMaskingAndPositionTests: 11 Metal tests guarded, fixed unused k/v bindings
- BatchRotatingKVCacheTests: all 22 tests guarded, fixed always-true 'is' checks
- KVCacheTests: .enabled(if:) guard for Swift Testing
- ChatSessionTests, EvalTests, SampleTests, NemotronHTests,
  MediaProcessingTests: guarded Metal-dependent tests

swift test --filter MLXLMTests now exits with code 0 (117 skipped, 20 pass).

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Auto-format 4 batch files using swift-format: fix import ordering
(@testable imports after regular imports) in 3 test files, and fix
line length violation in BatchKVCache.swift.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Capture the milestone synthesis, feature review reports, and shared validation knowledge for the next fix round.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…dable warning

- State getter now always includes batchOffsets and leftPadding even when
  keys/values are nil (fresh cache or emptied by filter([])). Setter handles
  both 2-element (empty) and 4-element (populated) state arrays.
- makeMask() now uses _idx directly as the offset (pre-update value) instead
  of _idx - n, aligning with how models call makeMask before cache.update().
- KVCacheTests.swift closure arguments annotated with @sendable to fix
  Swift Testing strict concurrency warning.
- Added round-trip tests for fresh and filter-emptied caches, plus makeMask
  pre-update call order tests.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…atchRotatingKVCache

- Implement prepare(leftPadding:lengths:rightPadding:) and finalize() methods
  matching Python mlx-lm's BatchRotatingKVCache for cached-prompt batch prefill
- Add dynamicRoll helper for per-batch element rolling
- Preserve RotatingKVCache.keep through merge/extract/fromSingle/toSingle paths
- Reject caches with different keep values in merge (same as maxSize rejection)
- Make RotatingKVCache.keep internal for cross-file access within module
- Update metaState serialization to include keep value
- Add _lengths state and integration with updateConcat/updateInPlace
- Add 15 new tests: keep round-trip, prepare/finalize, filter-extend with keep

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…g-window overflow

- trim(): Preserves first keep positions, only trims from window portion
- updateInPlace(): Wraps _idx to keep (not 0) so keep positions never overwritten
- temporalOrder(): Handles keep prefix correctly during rotation unrolling
- makeMask(): Rolls only the window portion of the mask when keep > 0
- extract(): Uses keep-aware rolling for rotated cache extraction
- Added 6 tests covering overflow preservation, wrap semantics, temporal
  ordering with keep, merge-extract round-trip after overflow, keep=0
  regression, and multiple rotation cycles

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Clamp leftPadding to non-negative (max(0, rawPadding)) before slicing
in extract() to prevent invalid array indices when the rotating cache
has overflowed. Updated testKeepOverflowMergeExtractRoundTrip to assert
actual key/value tensor contents, and added two new tests covering
negative leftPadding scenarios with and without keep prefix.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
The mask key dimension must equal _idx (total cached positions), not
_idx + n. Pass offset = _idx - n to createCausalMask so the produced
key-width is _idx. Fixes VAL-CACHE-011 (prefill doubling width) and
VAL-CACHE-020 (decode adding extra column).

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…during overflow

When keep > 0 and sequences have left-padding, the global keep zone
(positions 0..<keep) could contain padding zeros instead of actual
keep-prefix tokens. During rotation, writes into the window zone would
overwrite the real keep tokens that were shifted rightward by padding.

Fix: at the first rotation boundary, roll away each sequence's
left-padding via dynamicRoll so that per-sequence data starts at
position 0. This aligns the keep-prefix with the global keep zone,
preventing data corruption. Subsequent wraps are no-ops since
leftPadding is already ≤ 0.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Port Python mlx-lm BatchGenerator to Swift. Includes:
- PendingPrompt: queued prompt with tokens, sampler, processor, maxTokens
- ActiveBatch: holds UIDs, current tokens, caches, per-request state
- BatchTokenIterator: insert/next/remove/close API with prefill scheduling
- Left-padding, prompt sorting by length, chunked prefill
- Per-request sampler and LogitProcessor support
- 16 unit tests covering VAL-ENGINE-001 through VAL-ENGINE-012

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…atchTokenIterator

- Fix LogitProcessor lifecycle: add prompt() initialization during prefill and
  didSample() callback after sampling so penalty state tracks correctly
- Make step() accept processors as inout for proper mutation of penalty state
- Add 10 new tests: per-request sampler independence, processor state isolation,
  batch-vs-single numerical correctness with ArgMax, concurrent safety, asyncEval
  pipelining, processor prompt/didSample verification

Fulfills: VAL-ENGINE-013, VAL-ENGINE-014, VAL-ENGINE-015, VAL-ENGINE-016

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
…enIterator

- Decouple completionBatchSize from prefillBatchSize (no longer clamped)
- Admit min(freeSlots, prefillBatchSize, pendingCount) prompts per step
  so free decode slots are filled even when < prefillBatchSize available
- Add NSLock-based thread safety around all shared mutable state
- Mark BatchTokenIterator as @unchecked Sendable, Response as Sendable
- Update concurrency test with structural invariant assertions
- Add tests for independent batch sizes and partial admission

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
ronaldmannak and others added 27 commits March 27, 2026 15:05
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Remove the fallback path that skipped cache-type assertions when the
scheduler state was missed. Two changes:

1. Add testFromSinglePreservesRotatingKVCacheData: tests the
   BatchRotatingKVCache.fromSingle() conversion directly at the cache
   level with known data, verifying maxSize, keep, non-nil keys/values,
   correct offset, and data dimensions.

2. Rewrite testUpgradePreservesRotatingKVCacheState: use maxTokens:1000
   for the first request to guarantee it is still active when the second
   request arrives, ensuring the scheduler always reaches batched state.
   Remove the else branch that fell back to token-only checks.

The test now ALWAYS verifies cache layer types (BatchKVCache for
KVCacheSimple layers, BatchRotatingKVCache for RotatingKVCache layers).

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Store cache entries under inputTokens + generatedTokens instead of just
inputTokens, so the trie key depth matches the actual KV cache depth.
This matches upstream mlx-lm behavior where the prompt cache stores the
full context so prefix matches work correctly on subsequent lookups.

Changes:
- Single path: collect generated token IDs and write back under full sequence
- Batch path: track per-UID generated tokens and write back under full sequence
- Fix _deepCopy crash when cache has empty state (nil keys/values)
- Add regression tests: same prompt twice gets cache hit, key depth matches cache
- Update existing write-back tests for new key format

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Replace the fixed 50ms delay with a synchronization mechanism that waits
for the first stream to produce at least one token before submitting
the second request. This guarantees the first request is actively
generating when the upgrade triggers, eliminating timing-dependent
flakiness.

Also remove assertions about non-nil keys/values and offset > 0 in the
upgraded BatchRotatingKVCache, since the mock model does not call
cache.update(). Data preservation is already verified by the separate
testFromSinglePreservesRotatingKVCacheData test.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Change mock model formula from (lastToken + 1) % vocabSize to
(sum of input tokens % (vocabSize - 1)) + 1, guaranteeing output
tokens are always in range [1, vocabSize-1] and never hit EOS.
Keeps existing AsyncStream synchronization for deterministic upgrade.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
When the first request is upgraded from single to batch mode, its tokens
generated on the single path were not included in the batch write-back key.
This caused the trie key to be shorter than the actual KV cache depth.

Fix: carry generatedTokenIds through LiveIteratorState into the batch loop
and seed the first request's token list with those pre-upgrade tokens.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Move isRotating type check inside the per-layer loop in
processPartialCacheHits() so each layer is individually dispatched
to the correct batch cache path. Previously the blanket first-layer
check silently dropped RotatingKVCache data for mixed-layer models
like Gemma3. Add regression test with MockMixedLayerCacheModel.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Make the prompt-cache batching mock models advance KV caches so the mixed-depth cached-prefill test exercises the real final-cache extraction path and keeps cache metadata aligned. Strengthen the integration test to assert each finished request returns an extractable final cache.

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
@ronaldmannak ronaldmannak marked this pull request as ready for review March 28, 2026 21:40