Commit 49f9271

Implement context-length-dependent, KV-cache- and compute-buffer-aware layer distribution for heterogeneous multi-GPU inference. This solves the problem of attempting to run setups with mixed VRAM sizes (e.g. 24 GB cards alongside 6 GB cards); previously, layers were assigned without accounting for the compute buffer, causing failures when one or more of the smaller GPUs could not hold it.

- Add requested_n_ctx parameter to llama_model_params
- Implement a 3-pass allocation algorithm that accounts for compute buffers
- Exclude devices with insufficient memory (GPUs too small to hold 1 layer + KV cache + compute buffer)
- Redistribute layers to make equitable use of the included GPUs (may not be truly optimal)
1 parent de56944 commit 49f9271
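
For callers that use the C API directly (rather than going through common_model_params_to_llama, which now forwards params.n_ctx automatically), a minimal usage sketch of the new field is shown below. It is not part of this commit: the model path and sizes are placeholders, and it assumes the current llama.h entry points (llama_model_default_params, llama_model_load_from_file, llama_init_from_model).

    #include "llama.h"

    int main() {
        const uint32_t n_ctx = 32768;               // placeholder: the context size you actually plan to run

        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        mparams.n_gpu_layers    = 999;              // offload as many layers as the GPUs can take
        mparams.requested_n_ctx = n_ctx;            // new field: lets load_tensors plan KV cache + compute buffers

        llama_model * model = llama_model_load_from_file("model.gguf", mparams);   // placeholder path
        if (model == nullptr) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx = n_ctx;                      // keep in sync with requested_n_ctx so the plan matches reality

        llama_context * ctx = llama_init_from_model(model, cparams);
        // ... inference ...

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }

With the default layer split mode and more than one GPU, the loader then sizes the KV cache and compute buffer for the requested 32768 tokens instead of guessing, and excludes GPUs that cannot hold even one layer plus those buffers.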

File tree: 3 files changed, 312 additions, 0 deletions

  common/common.cpp
  include/llama.h
  src/llama-model.cpp

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1107,6 +1107,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     mparams.use_mmap = params.use_mmap;
     mparams.use_mlock = params.use_mlock;
     mparams.check_tensors = params.check_tensors;
+    mparams.requested_n_ctx = params.n_ctx;
 
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;

include/llama.h

Lines changed: 3 additions & 0 deletions
@@ -322,6 +322,9 @@ extern "C" {
         // override key-value pairs of the model meta data
         const struct llama_model_kv_override * kv_overrides;
 
+        // expected context size for memory allocation planning (0 = auto)
+        uint32_t requested_n_ctx;
+
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mmap;   // use mmap if possible

src/llama-model.cpp

Lines changed: 308 additions & 0 deletions
@@ -17,10 +17,12 @@
 #include <cassert>
 #include <cmath>
 #include <cfloat>
+#include <cstdlib>
 #include <cstring>
 #include <cmath>
 #include <functional>
 #include <map>
+#include <numeric>
 #include <regex>
 #include <sstream>
 #include <stdexcept>
@@ -1580,6 +1582,311 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         splits[i] /= split_sum;
     }
 
+    // KV-cache aware layer distribution for heterogeneous GPUs
+    if (all_zero && n_devices() > 1 && split_mode == LLAMA_SPLIT_MODE_LAYER) {
+        // Determine context size for memory planning
+        uint32_t n_ctx_for_kv = 0;
+        if (params.requested_n_ctx > 0) {
+            // Use the explicitly requested context size from model params
+            n_ctx_for_kv = params.requested_n_ctx;
+            LLAMA_LOG_INFO("%s: Using requested_n_ctx=%u for KV cache calculation\n",
+                           __func__, n_ctx_for_kv);
+        } else {
+            // Use a conservative default for memory planning
+            n_ctx_for_kv = std::min(32768u, hparams.n_ctx_train);
+            LLAMA_LOG_INFO("%s: Using default n_ctx=%u for KV cache calculation (training context: %u)\n",
+                           __func__, n_ctx_for_kv, hparams.n_ctx_train);
+            LLAMA_LOG_INFO("%s: (set requested_n_ctx in model params to match your actual context size)\n", __func__);
+        }
+
+        // Only apply KV-aware distribution if we have a valid context size
+        if (n_ctx_for_kv > 0 && n_gpu_layers > 0) {
+            LLAMA_LOG_INFO("%s: Implementing KV-cache aware layer distribution\n", __func__);
+
+            // Calculate memory requirements per layer
+            const int64_t n_head_kv = hparams.n_head_kv();
+            const int64_t n_embd_head = hparams.n_embd_head_k;
+            const int64_t n_embd_kv = n_embd_head * n_head_kv;
+
+            // KV cache element size (typically f16 = 2 bytes, but can be quantized)
+            const size_t kv_size_element = 2; // sizeof(ggml_fp16_t)
+
+            // Total KV cache size for all layers (K and V)
+            // KV cache = 2 (K+V) * n_ctx * n_layers * n_embd_kv * element_size
+            const size_t kv_cache_size_total = 2ULL * n_ctx_for_kv * n_layer * n_embd_kv * kv_size_element;
+
+            // Estimate model weight size per layer
+            const size_t model_size_total = ml.n_bytes;
+            const size_t weight_size_per_layer = model_size_total / n_layer;
+
+            // Calculate actual compute buffer size based on attention matrix requirements
+            // Attention matrix: n_kv × n_ubatch × n_head × sizeof(float)
+            // This is the dominant memory consumer during inference
+            const int64_t n_head = hparams.n_head();
+            const size_t n_ubatch = 512; // Default physical batch size (from context params)
+            const size_t compute_buffer_size = n_ctx_for_kv * n_ubatch * n_head * sizeof(float);
+            const size_t min_overhead = 512ULL * 1024 * 1024; // 512MB base overhead
+
+            LLAMA_LOG_INFO("%s: Compute buffer size: %.2f MB (context=%u, ubatch=%zu, heads=%lld)\n",
+                           __func__,
+                           compute_buffer_size / 1024.0 / 1024.0,
+                           n_ctx_for_kv, n_ubatch, (long long)n_head);
+
+            // For memory calculation, we need to account for KV cache being shared across layers on each device
+            // We'll calculate this dynamically during layer assignment
+
+            LLAMA_LOG_INFO("%s: Per-layer memory: weights=%.2f MB\n",
+                           __func__,
+                           weight_size_per_layer / 1024.0 / 1024.0);
+            LLAMA_LOG_INFO("%s: Total KV cache size: %.2f MB\n",
+                           __func__,
+                           kv_cache_size_total / 1024.0 / 1024.0);
+
+            // Get memory info and calculate layer assignments
+            std::vector<int> layers_per_gpu(n_devices(), 0);
+            std::vector<size_t> gpu_free_memory(n_devices());
+
+            // Get free memory for each device and check if they can handle compute buffers
+            std::vector<bool> device_excluded(n_devices(), false);
+            for (size_t i = 0; i < n_devices(); ++i) {
+                ggml_backend_dev_t dev = devices[i];
+                size_t total, free;
+                ggml_backend_dev_memory(dev, &free, &total);
+                gpu_free_memory[i] = free;
+
+                // Check if device can handle minimum requirements (1 layer + compute buffer + KV cache)
+                size_t min_kv_cache = kv_cache_size_total / n_devices(); // Conservative estimate
+                size_t min_required = weight_size_per_layer + min_kv_cache + compute_buffer_size + min_overhead;
+
+                if (free < min_required) {
+                    device_excluded[i] = true;
+                    LLAMA_LOG_WARN("%s: Device %zu [%s]: %.2f MB free - excluding (needs %.2f MB minimum)\n",
+                                   __func__, i, ggml_backend_dev_name(dev),
+                                   free / 1024.0 / 1024.0, min_required / 1024.0 / 1024.0);
+                }
+            }
+
+            // Estimate total memory requirements and warn if insufficient
+            size_t total_gpu_memory = 0;
+            for (size_t i = 0; i < n_devices(); ++i) {
+                total_gpu_memory += gpu_free_memory[i];
+            }
+
+            // Rough estimate: KV cache + model weights + compute buffers (conservative estimate)
+            size_t estimated_compute_buffers = kv_cache_size_total; // Compute buffers often similar to KV cache size
+            size_t estimated_total_needed = kv_cache_size_total + model_size_total + estimated_compute_buffers;
+
+            if (estimated_total_needed > total_gpu_memory) {
+                LLAMA_LOG_WARN("%s: Memory estimate: %.2f GB needed vs %.2f GB available\n",
+                               __func__,
+                               estimated_total_needed / 1024.0 / 1024.0 / 1024.0,
+                               total_gpu_memory / 1024.0 / 1024.0 / 1024.0);
+                LLAMA_LOG_WARN("%s: Context size may be too large for available memory\n", __func__);
+            }
+
+            // Sort devices by available memory (largest first), excluding unusable devices
+            std::vector<size_t> gpu_indices;
+            for (size_t i = 0; i < n_devices(); ++i) {
+                if (!device_excluded[i]) {
+                    gpu_indices.push_back(i);
+                }
+            }
+            std::sort(gpu_indices.begin(), gpu_indices.end(),
+                      [&gpu_free_memory](size_t a, size_t b) {
+                          return gpu_free_memory[a] > gpu_free_memory[b];
+                      });
+
+            if (gpu_indices.empty()) {
+                LLAMA_LOG_ERROR("%s: No GPUs have sufficient memory for compute buffers\n", __func__);
+                // Fall back to original allocation
+                return true;
+            }
+
+            // Assign layers greedily to GPUs with most memory first
+            int act_gpu_layers = n_gpu_layers; // Local copy that can be modified
+            int remaining_layers = act_gpu_layers;
+
+            // First pass: assign layers based on weights only (KV cache and compute buffers handled separately)
+            size_t weight_per_layer = weight_size_per_layer;
+
+            for (size_t idx : gpu_indices) {
+                // Reserve memory for compute buffer and base overhead
+                size_t reserved = compute_buffer_size + min_overhead;
+                if (gpu_free_memory[idx] <= reserved) {
+                    LLAMA_LOG_WARN("%s: Device %zu [%s]: %zu MB free, can't fit compute buffer (%.2f MB)\n",
+                                   __func__, idx, ggml_backend_dev_name(devices[idx]),
+                                   gpu_free_memory[idx] / 1024 / 1024,
+                                   reserved / 1024.0 / 1024.0);
+                    continue;
+                }
+
+                size_t available_for_model = gpu_free_memory[idx] - reserved;
+                int layers_that_fit = available_for_model / weight_per_layer;
+
+                if (layers_that_fit > 0 && remaining_layers > 0) {
+                    int layers_to_assign = std::min(layers_that_fit, remaining_layers);
+                    layers_per_gpu[idx] = layers_to_assign;
+                    remaining_layers -= layers_to_assign;
+
+                    LLAMA_LOG_INFO("%s: Device %zu [%s]: %zu MB free, assigned %d layers (%.2f MB weights, %.2f MB compute buffer)\n",
+                                   __func__, idx, ggml_backend_dev_name(devices[idx]),
+                                   gpu_free_memory[idx] / 1024 / 1024,
+                                   layers_per_gpu[idx],
+                                   (layers_to_assign * weight_per_layer) / 1024.0 / 1024.0,
+                                   compute_buffer_size / 1024.0 / 1024.0);
+                } else {
+                    LLAMA_LOG_WARN("%s: Device %zu [%s]: %zu MB free, assigned 0 layers (need %.2f MB per layer + %.2f MB compute buffer)\n",
+                                   __func__, idx, ggml_backend_dev_name(devices[idx]),
+                                   gpu_free_memory[idx] / 1024 / 1024,
+                                   weight_per_layer / 1024.0 / 1024.0,
+                                   compute_buffer_size / 1024.0 / 1024.0);
+                }
+            }
+
+            // Second pass: iteratively check if KV cache can fit proportionally
+            bool kv_fit_check_needed = (remaining_layers == 0);
+            int iterations = 0;
+            const int max_iterations = 10;
+
+            while (kv_fit_check_needed && iterations < max_iterations) {
+                kv_fit_check_needed = false;
+                iterations++;
+
+                // Calculate current total assigned layers
+                int total_assigned = 0;
+                for (size_t idx = 0; idx < n_devices(); ++idx) {
+                    total_assigned += layers_per_gpu[idx];
+                }
+
+                if (total_assigned == 0) break;
+
+                // Check KV cache distribution for each device
+                for (size_t idx = 0; idx < n_devices(); ++idx) {
+                    if (layers_per_gpu[idx] > 0) {
+                        double layer_ratio = (double)layers_per_gpu[idx] / total_assigned;
+                        size_t kv_cache_for_device = (size_t)(kv_cache_size_total * layer_ratio);
+                        size_t weights = layers_per_gpu[idx] * weight_per_layer;
+                        size_t total_memory_needed = weights + kv_cache_for_device + compute_buffer_size + min_overhead;
+
+                        if (total_memory_needed > gpu_free_memory[idx]) {
+                            // Device can't fit current allocation, reduce layers
+                            size_t available_memory = gpu_free_memory[idx];
+                            if (available_memory > min_overhead + kv_cache_for_device + compute_buffer_size) {
+                                size_t available_for_weights = available_memory - min_overhead - kv_cache_for_device - compute_buffer_size;
+                                int new_layer_count = available_for_weights / weight_per_layer;
+                                new_layer_count = std::max(0, new_layer_count);
+
+                                if (new_layer_count < layers_per_gpu[idx]) {
+                                    LLAMA_LOG_WARN("%s: Device %zu: Reducing layers from %d to %d due to KV cache requirements (%.2f MB KV cache)\n",
+                                                   __func__, idx, layers_per_gpu[idx], new_layer_count,
+                                                   kv_cache_for_device / 1024.0 / 1024.0);
+                                    remaining_layers += layers_per_gpu[idx] - new_layer_count;
+                                    layers_per_gpu[idx] = new_layer_count;
+                                    kv_fit_check_needed = true;
+                                }
+                            } else {
+                                // Device can't even fit the minimum requirements
+                                LLAMA_LOG_WARN("%s: Device %zu: Removing all %d layers - insufficient memory for KV cache\n",
+                                               __func__, idx, layers_per_gpu[idx]);
+                                remaining_layers += layers_per_gpu[idx];
+                                layers_per_gpu[idx] = 0;
+                                kv_fit_check_needed = true;
+                            }
+                        }
+                    }
+                }
+            }
+
+            // Third pass: redistribute any remaining layers to devices with available capacity
+            if (remaining_layers > 0) {
+                LLAMA_LOG_INFO("%s: Attempting to redistribute %d remaining layers\n", __func__, remaining_layers);
+
+                // Calculate current memory usage for each device that has layers assigned
+                for (size_t idx : gpu_indices) {
+                    if (layers_per_gpu[idx] > 0 && remaining_layers > 0) {
+                        // Calculate current memory usage
+                        int current_assigned = 0;
+                        for (size_t i = 0; i < n_devices(); ++i) {
+                            current_assigned += layers_per_gpu[i];
+                        }
+
+                        double layer_ratio = (double)layers_per_gpu[idx] / current_assigned;
+                        size_t current_kv_cache = (size_t)(kv_cache_size_total * layer_ratio);
+                        size_t current_weights = layers_per_gpu[idx] * weight_per_layer;
+                        size_t current_usage = current_weights + current_kv_cache + compute_buffer_size + min_overhead;
+
+                        if (gpu_free_memory[idx] > current_usage) {
+                            // Calculate how many additional layers could fit
+                            // We need to account for proportional increase in KV cache
+                            int additional_layers = 0;
+                            for (int test_layers = 1; test_layers <= remaining_layers; test_layers++) {
+                                int new_total_layers = layers_per_gpu[idx] + test_layers;
+                                int new_total_assigned = current_assigned + test_layers;
+                                double new_layer_ratio = (double)new_total_layers / new_total_assigned;
+                                size_t new_kv_cache = (size_t)(kv_cache_size_total * new_layer_ratio);
+                                size_t new_weights = new_total_layers * weight_per_layer;
+                                size_t new_total_usage = new_weights + new_kv_cache + compute_buffer_size + min_overhead;
+
+                                if (new_total_usage <= gpu_free_memory[idx]) {
+                                    additional_layers = test_layers;
+                                } else {
+                                    break;
+                                }
+                            }
+
+                            if (additional_layers > 0) {
+                                int layers_to_add = std::min(additional_layers, remaining_layers);
+                                layers_per_gpu[idx] += layers_to_add;
+                                remaining_layers -= layers_to_add;
+
+                                LLAMA_LOG_INFO("%s: Device %zu [%s]: redistributed %d additional layers (total now %d)\n",
+                                               __func__, idx, ggml_backend_dev_name(devices[idx]),
+                                               layers_to_add, layers_per_gpu[idx]);
+                            }
+                        }
+                    }
+                }
+            }
+
+            // Warn if we couldn't place all layers
+            if (remaining_layers > 0) {
+                LLAMA_LOG_ERROR("%s: WARNING: Could not assign %d layers to GPUs. Consider:\n",
+                                __func__, remaining_layers);
+                LLAMA_LOG_ERROR("%s: - Reducing context size (current: %u)\n",
+                                __func__, n_ctx_for_kv);
+                LLAMA_LOG_ERROR("%s: - Using fewer layers (-ngl)\n", __func__);
+                LLAMA_LOG_ERROR("%s: - Adding more GPU memory\n", __func__);
+
+                // Put remaining layers on CPU (will be updated below)
+            }
+
+            // Convert layer counts to split ratios
+            splits.clear();
+            splits.resize(n_devices());
+            float cumsum = 0.0f;
+
+            // Calculate total layers actually assigned
+            int total_assigned_layers = 0;
+            for (size_t i = 0; i < n_devices(); ++i) {
+                total_assigned_layers += layers_per_gpu[i];
+            }
+
+            // Update act_gpu_layers to match what we actually assigned
+            act_gpu_layers = total_assigned_layers;
+
+            for (size_t i = 0; i < n_devices(); ++i) {
+                cumsum += (float)layers_per_gpu[i] / act_gpu_layers;
+                splits[i] = cumsum;
+            }
+
+            LLAMA_LOG_INFO("%s: Final split ratios: ", __func__);
+            for (size_t i = 0; i < n_devices(); ++i) {
+                LLAMA_LOG_CONT("%.3f ", splits[i]);
+            }
+            LLAMA_LOG_CONT("\n");
+        }
+    }
+
     ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
     if (cpu_dev == nullptr) {
         throw std::runtime_error(format("%s: no CPU backend found", __func__));
@@ -14837,6 +15144,7 @@ llama_model_params llama_model_default_params() {
         /*.progress_callback =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.kv_overrides =*/ nullptr,
+        /*.requested_n_ctx =*/ 0,
         /*.vocab_only =*/ false,
         /*.use_mmap =*/ true,
         /*.use_mlock =*/ false,
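
To make the sizing arithmetic in the load_tensors hunk above concrete, here is a small standalone sketch, not part of the commit, that evaluates the same formulas for a hypothetical model with 32 layers, 32 attention heads, 8 KV heads, and a head dimension of 128 at the default planning context of 32768 (all figures are illustrative):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical hparams, chosen only to illustrate the formulas used in load_tensors.
        const uint32_t n_ctx_for_kv    = 32768;                    // requested_n_ctx, or the 32768 default
        const int64_t  n_layer         = 32;
        const int64_t  n_head          = 32;
        const int64_t  n_head_kv       = 8;
        const int64_t  n_embd_head     = 128;
        const int64_t  n_embd_kv       = n_embd_head * n_head_kv;  // 1024
        const size_t   kv_size_element = 2;                        // f16
        const size_t   n_ubatch        = 512;

        // Same formulas as the commit: total KV cache, per-GPU compute buffer, fixed overhead.
        const size_t kv_cache_size_total = 2ULL * n_ctx_for_kv * n_layer * n_embd_kv * kv_size_element;
        const size_t compute_buffer_size = (size_t) n_ctx_for_kv * n_ubatch * n_head * sizeof(float);
        const size_t min_overhead        = 512ULL * 1024 * 1024;

        std::printf("KV cache total : %.2f MB\n", kv_cache_size_total / 1024.0 / 1024.0); // 4096.00 MB
        std::printf("Compute buffer : %.2f MB\n", compute_buffer_size / 1024.0 / 1024.0); // 2048.00 MB
        std::printf("Base overhead  : %.2f MB\n", min_overhead / 1024.0 / 1024.0);        //  512.00 MB
        return 0;
    }

With these numbers, the first pass reserves roughly 2048 MB of compute buffer plus 512 MB of base overhead on every GPU before placing any weights or KV cache, which is why a 6 GB card can be excluded outright while a 24 GB card absorbs most of the layers. If, say, 40 of 48 offloaded layers end up on device 0 and 8 on device 1, the cumulative split ratios logged at the end come out as 0.833 and 1.000.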
