Commit 349fd29

borebot authored and Nexesenex committed
Implement context-length dependent KV-cache and compute-buffer aware layer distribution for heterogeneous multi-GPU inference.

This solves the problem of attempting to run setups with mixed VRAM sizes (e.g. 24GB cards alongside 6GB cards): previously, layers were assigned without accounting for the compute buffer, causing failure when one or more of the smaller GPUs could not hold it.

- Add requested_n_ctx parameter to llama_model_params
- Implement a 3-pass allocation algorithm that accounts for compute buffers
- Exclude devices with insufficient memory (GPUs too small to hold 1 layer + KV cache + compute buffer)
- Redistribute layers to make equitable use of the included GPUs (may not be truly optimal)
1 parent 6d6ec07 commit 349fd29
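
To make the exclusion rule concrete, here is a standalone sketch (not part of the diff) that replays the per-device check with made-up numbers for a 24 GB card paired with a 6 GB card; the model shape, buffer sizes, and variable names are illustrative assumptions only, not values taken from the commit.

// Illustrative only: mirrors the "1 layer + KV share + compute buffer + overhead"
// exclusion rule with hypothetical numbers (24 GB card vs 6 GB card).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint64_t MiB = 1024ull * 1024;

    // Hypothetical planning figures: 32 layers of ~448 MiB each, 8 GiB of total
    // KV cache at the planned context, 2 GiB compute buffer, 512 MiB base overhead.
    const uint64_t weight_per_layer    = 448  * MiB;
    const uint64_t kv_cache_total      = 8192 * MiB;
    const uint64_t compute_buffer_size = 2048 * MiB;
    const uint64_t min_overhead        = 512  * MiB;

    // Free VRAM reported by the two hypothetical devices.
    const std::vector<uint64_t> gpu_free = { 23 * 1024 * MiB, 5 * 1024 * MiB };

    for (size_t i = 0; i < gpu_free.size(); ++i) {
        // Conservative per-device KV share used by the exclusion check.
        const uint64_t min_kv_share = kv_cache_total / gpu_free.size();
        const uint64_t min_required = weight_per_layer + min_kv_share +
                                      compute_buffer_size + min_overhead;
        const bool excluded = gpu_free[i] < min_required;
        std::printf("device %zu: %llu MiB free, needs >= %llu MiB -> %s\n",
                    i,
                    (unsigned long long)(gpu_free[i] / MiB),
                    (unsigned long long)(min_required / MiB),
                    excluded ? "excluded" : "usable");
    }
    return 0;
}

Under these assumptions the 6 GB card cannot hold even one layer plus its KV share and compute buffer, so the planner would exclude it up front rather than fail at allocation time.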

File tree

3 files changed, +312 -0 lines changed

common/common.cpp
include/llama.h
src/llama-model.cpp

common/common.cpp

Lines changed: 1 addition & 0 deletions

@@ -1115,6 +1115,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     mparams.use_mmap = params.use_mmap;
     mparams.use_mlock = params.use_mlock;
     mparams.check_tensors = params.check_tensors;
+    mparams.requested_n_ctx = params.n_ctx;
 
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;

include/llama.h

Lines changed: 3 additions & 0 deletions

@@ -375,6 +375,9 @@ extern "C" {
         // override key-value pairs of the model meta data
         const struct llama_model_kv_override * kv_overrides;
 
+        // expected context size for memory allocation planning (0 = auto)
+        uint32_t requested_n_ctx;
+
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool vocab_only;    // only load the vocabulary, no weights
         bool use_mmap;      // use mmap if possible
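
Command-line users get the new field populated automatically by the common/common.cpp change above, which copies params.n_ctx into it. A program that drives the C API directly would set the field itself, roughly as in the sketch below; the entry points shown (llama_model_default_params, llama_load_model_from_file, llama_free_model) are the long-standing loader calls (newer trees spell them llama_model_load_from_file / llama_model_free), and the model path is a placeholder.

// Sketch: tell the loader which context size you actually plan to create,
// so the layer planner can budget KV cache and compute buffers for it.
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers    = 999;     // offload as many layers as the planner allows
    mparams.requested_n_ctx = 32768;   // should match the n_ctx used at context creation

    llama_model * model = llama_load_model_from_file("/path/to/model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    // ... create the context with n_ctx = 32768 and run inference ...

    llama_free_model(model);
    return 0;
}

If requested_n_ctx is left at its default of 0, the planner falls back to min(32768, n_ctx_train), as the src/llama-model.cpp hunk below shows.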

src/llama-model.cpp

Lines changed: 308 additions & 0 deletions

@@ -17,10 +17,12 @@
 #include <cassert>
 #include <cmath>
 #include <cfloat>
+#include <cstdlib>
 #include <cstring>
 #include <cmath>
 #include <functional>
 #include <map>
+#include <numeric>
 #include <regex>
 #include <sstream>
 #include <stdexcept>
@@ -1621,6 +1623,311 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         splits[i] /= split_sum;
     }
 
+    // KV-cache aware layer distribution for heterogeneous GPUs
+    if (all_zero && n_devices() > 1 && split_mode == LLAMA_SPLIT_MODE_LAYER) {
+        // Determine context size for memory planning
+        uint32_t n_ctx_for_kv = 0;
+        if (params.requested_n_ctx > 0) {
+            // Use the explicitly requested context size from model params
+            n_ctx_for_kv = params.requested_n_ctx;
+            LLAMA_LOG_INFO("%s: Using requested_n_ctx=%u for KV cache calculation\n",
+                           __func__, n_ctx_for_kv);
+        } else {
+            // Use a conservative default for memory planning
+            n_ctx_for_kv = std::min(32768u, hparams.n_ctx_train);
+            LLAMA_LOG_INFO("%s: Using default n_ctx=%u for KV cache calculation (training context: %u)\n",
+                           __func__, n_ctx_for_kv, hparams.n_ctx_train);
+            LLAMA_LOG_INFO("%s: (set requested_n_ctx in model params to match your actual context size)\n", __func__);
+        }
+
+        // Only apply KV-aware distribution if we have a valid context size
+        if (n_ctx_for_kv > 0 && n_gpu_layers > 0) {
+            LLAMA_LOG_INFO("%s: Implementing KV-cache aware layer distribution\n", __func__);
+
+            // Calculate memory requirements per layer
+            const int64_t n_head_kv = hparams.n_head_kv();
+            const int64_t n_embd_head = hparams.n_embd_head_k;
+            const int64_t n_embd_kv = n_embd_head * n_head_kv;
+
+            // KV cache element size (typically f16 = 2 bytes, but can be quantized)
+            const size_t kv_size_element = 2; // sizeof(ggml_fp16_t)
+
+            // Total KV cache size for all layers (K and V)
+            // KV cache = 2 (K+V) * n_ctx * n_layers * n_embd_kv * element_size
+            const size_t kv_cache_size_total = 2ULL * n_ctx_for_kv * n_layer * n_embd_kv * kv_size_element;
+
+            // Estimate model weight size per layer
+            const size_t model_size_total = ml.n_bytes;
+            const size_t weight_size_per_layer = model_size_total / n_layer;
+
+            // Calculate actual compute buffer size based on attention matrix requirements
+            // Attention matrix: n_kv × n_ubatch × n_head × sizeof(float)
+            // This is the dominant memory consumer during inference
+            const int64_t n_head = hparams.n_head();
+            const size_t n_ubatch = 512; // Default physical batch size (from context params)
+            const size_t compute_buffer_size = n_ctx_for_kv * n_ubatch * n_head * sizeof(float);
+            const size_t min_overhead = 512ULL * 1024 * 1024; // 512MB base overhead
+
+            LLAMA_LOG_INFO("%s: Compute buffer size: %.2f MB (context=%u, ubatch=%zu, heads=%lld)\n",
+                           __func__,
+                           compute_buffer_size / 1024.0 / 1024.0,
+                           n_ctx_for_kv, n_ubatch, (long long)n_head);
+
+            // For memory calculation, we need to account for KV cache being shared across layers on each device
+            // We'll calculate this dynamically during layer assignment
+
+            LLAMA_LOG_INFO("%s: Per-layer memory: weights=%.2f MB\n",
+                           __func__,
+                           weight_size_per_layer / 1024.0 / 1024.0);
+            LLAMA_LOG_INFO("%s: Total KV cache size: %.2f MB\n",
+                           __func__,
+                           kv_cache_size_total / 1024.0 / 1024.0);
+
+            // Get memory info and calculate layer assignments
+            std::vector<int> layers_per_gpu(n_devices(), 0);
+            std::vector<size_t> gpu_free_memory(n_devices());
+
+            // Get free memory for each device and check if they can handle compute buffers
+            std::vector<bool> device_excluded(n_devices(), false);
+            for (size_t i = 0; i < n_devices(); ++i) {
+                ggml_backend_dev_t dev = devices[i];
+                size_t total, free;
+                ggml_backend_dev_memory(dev, &free, &total);
+                gpu_free_memory[i] = free;
+
+                // Check if device can handle minimum requirements (1 layer + compute buffer + KV cache)
+                size_t min_kv_cache = kv_cache_size_total / n_devices(); // Conservative estimate
+                size_t min_required = weight_size_per_layer + min_kv_cache + compute_buffer_size + min_overhead;
+
+                if (free < min_required) {
+                    device_excluded[i] = true;
+                    LLAMA_LOG_WARN("%s: Device %zu [%s]: %.2f MB free - excluding (needs %.2f MB minimum)\n",
+                                   __func__, i, ggml_backend_dev_name(dev),
+                                   free / 1024.0 / 1024.0, min_required / 1024.0 / 1024.0);
+                }
+            }
+
+            // Estimate total memory requirements and warn if insufficient
+            size_t total_gpu_memory = 0;
+            for (size_t i = 0; i < n_devices(); ++i) {
+                total_gpu_memory += gpu_free_memory[i];
+            }
+
+            // Rough estimate: KV cache + model weights + compute buffers (conservative estimate)
+            size_t estimated_compute_buffers = kv_cache_size_total; // Compute buffers often similar to KV cache size
+            size_t estimated_total_needed = kv_cache_size_total + model_size_total + estimated_compute_buffers;
+
+            if (estimated_total_needed > total_gpu_memory) {
+                LLAMA_LOG_WARN("%s: Memory estimate: %.2f GB needed vs %.2f GB available\n",
+                               __func__,
+                               estimated_total_needed / 1024.0 / 1024.0 / 1024.0,
+                               total_gpu_memory / 1024.0 / 1024.0 / 1024.0);
+                LLAMA_LOG_WARN("%s: Context size may be too large for available memory\n", __func__);
+            }
+
+            // Sort devices by available memory (largest first), excluding unusable devices
+            std::vector<size_t> gpu_indices;
+            for (size_t i = 0; i < n_devices(); ++i) {
+                if (!device_excluded[i]) {
+                    gpu_indices.push_back(i);
+                }
+            }
+            std::sort(gpu_indices.begin(), gpu_indices.end(),
+                      [&gpu_free_memory](size_t a, size_t b) {
+                          return gpu_free_memory[a] > gpu_free_memory[b];
+                      });
+
+            if (gpu_indices.empty()) {
+                LLAMA_LOG_ERROR("%s: No GPUs have sufficient memory for compute buffers\n", __func__);
+                // Fall back to original allocation
+                return true;
+            }
+
+            // Assign layers greedily to GPUs with most memory first
+            int act_gpu_layers = n_gpu_layers; // Local copy that can be modified
+            int remaining_layers = act_gpu_layers;
+
+            // First pass: assign layers based on weights only (KV cache and compute buffers handled separately)
+            size_t weight_per_layer = weight_size_per_layer;
+
+            for (size_t idx : gpu_indices) {
+                // Reserve memory for compute buffer and base overhead
+                size_t reserved = compute_buffer_size + min_overhead;
+                if (gpu_free_memory[idx] <= reserved) {
+                    LLAMA_LOG_WARN("%s: Device %zu [%s]: %zu MB free, can't fit compute buffer (%.2f MB)\n",
+                                   __func__, idx, ggml_backend_dev_name(devices[idx]),
+                                   gpu_free_memory[idx] / 1024 / 1024,
+                                   reserved / 1024.0 / 1024.0);
+                    continue;
+                }
+
+                size_t available_for_model = gpu_free_memory[idx] - reserved;
+                int layers_that_fit = available_for_model / weight_per_layer;
+
+                if (layers_that_fit > 0 && remaining_layers > 0) {
+                    int layers_to_assign = std::min(layers_that_fit, remaining_layers);
+                    layers_per_gpu[idx] = layers_to_assign;
+                    remaining_layers -= layers_to_assign;
+
+                    LLAMA_LOG_INFO("%s: Device %zu [%s]: %zu MB free, assigned %d layers (%.2f MB weights, %.2f MB compute buffer)\n",
+                                   __func__, idx, ggml_backend_dev_name(devices[idx]),
+                                   gpu_free_memory[idx] / 1024 / 1024,
+                                   layers_per_gpu[idx],
+                                   (layers_to_assign * weight_per_layer) / 1024.0 / 1024.0,
+                                   compute_buffer_size / 1024.0 / 1024.0);
+                } else {
+                    LLAMA_LOG_WARN("%s: Device %zu [%s]: %zu MB free, assigned 0 layers (need %.2f MB per layer + %.2f MB compute buffer)\n",
+                                   __func__, idx, ggml_backend_dev_name(devices[idx]),
+                                   gpu_free_memory[idx] / 1024 / 1024,
+                                   weight_per_layer / 1024.0 / 1024.0,
+                                   compute_buffer_size / 1024.0 / 1024.0);
+                }
+            }
+
+            // Second pass: iteratively check if KV cache can fit proportionally
+            bool kv_fit_check_needed = (remaining_layers == 0);
+            int iterations = 0;
+            const int max_iterations = 10;
+
+            while (kv_fit_check_needed && iterations < max_iterations) {
+                kv_fit_check_needed = false;
+                iterations++;
+
+                // Calculate current total assigned layers
+                int total_assigned = 0;
+                for (size_t idx = 0; idx < n_devices(); ++idx) {
+                    total_assigned += layers_per_gpu[idx];
+                }
+
+                if (total_assigned == 0) break;
+
+                // Check KV cache distribution for each device
+                for (size_t idx = 0; idx < n_devices(); ++idx) {
+                    if (layers_per_gpu[idx] > 0) {
+                        double layer_ratio = (double)layers_per_gpu[idx] / total_assigned;
+                        size_t kv_cache_for_device = (size_t)(kv_cache_size_total * layer_ratio);
+                        size_t weights = layers_per_gpu[idx] * weight_per_layer;
+                        size_t total_memory_needed = weights + kv_cache_for_device + compute_buffer_size + min_overhead;
+
+                        if (total_memory_needed > gpu_free_memory[idx]) {
+                            // Device can't fit current allocation, reduce layers
+                            size_t available_memory = gpu_free_memory[idx];
+                            if (available_memory > min_overhead + kv_cache_for_device + compute_buffer_size) {
+                                size_t available_for_weights = available_memory - min_overhead - kv_cache_for_device - compute_buffer_size;
+                                int new_layer_count = available_for_weights / weight_per_layer;
+                                new_layer_count = std::max(0, new_layer_count);
+
+                                if (new_layer_count < layers_per_gpu[idx]) {
+                                    LLAMA_LOG_WARN("%s: Device %zu: Reducing layers from %d to %d due to KV cache requirements (%.2f MB KV cache)\n",
+                                                   __func__, idx, layers_per_gpu[idx], new_layer_count,
+                                                   kv_cache_for_device / 1024.0 / 1024.0);
+                                    remaining_layers += layers_per_gpu[idx] - new_layer_count;
+                                    layers_per_gpu[idx] = new_layer_count;
+                                    kv_fit_check_needed = true;
+                                }
+                            } else {
+                                // Device can't even fit the minimum requirements
+                                LLAMA_LOG_WARN("%s: Device %zu: Removing all %d layers - insufficient memory for KV cache\n",
+                                               __func__, idx, layers_per_gpu[idx]);
+                                remaining_layers += layers_per_gpu[idx];
+                                layers_per_gpu[idx] = 0;
+                                kv_fit_check_needed = true;
+                            }
+                        }
+                    }
+                }
+            }
+
+            // Third pass: redistribute any remaining layers to devices with available capacity
+            if (remaining_layers > 0) {
+                LLAMA_LOG_INFO("%s: Attempting to redistribute %d remaining layers\n", __func__, remaining_layers);
+
+                // Calculate current memory usage for each device that has layers assigned
+                for (size_t idx : gpu_indices) {
+                    if (layers_per_gpu[idx] > 0 && remaining_layers > 0) {
+                        // Calculate current memory usage
+                        int current_assigned = 0;
+                        for (size_t i = 0; i < n_devices(); ++i) {
+                            current_assigned += layers_per_gpu[i];
+                        }
+
+                        double layer_ratio = (double)layers_per_gpu[idx] / current_assigned;
+                        size_t current_kv_cache = (size_t)(kv_cache_size_total * layer_ratio);
+                        size_t current_weights = layers_per_gpu[idx] * weight_per_layer;
+                        size_t current_usage = current_weights + current_kv_cache + compute_buffer_size + min_overhead;
+
+                        if (gpu_free_memory[idx] > current_usage) {
+                            // Calculate how many additional layers could fit
+                            // We need to account for proportional increase in KV cache
+                            int additional_layers = 0;
+                            for (int test_layers = 1; test_layers <= remaining_layers; test_layers++) {
+                                int new_total_layers = layers_per_gpu[idx] + test_layers;
+                                int new_total_assigned = current_assigned + test_layers;
+                                double new_layer_ratio = (double)new_total_layers / new_total_assigned;
+                                size_t new_kv_cache = (size_t)(kv_cache_size_total * new_layer_ratio);
+                                size_t new_weights = new_total_layers * weight_per_layer;
+                                size_t new_total_usage = new_weights + new_kv_cache + compute_buffer_size + min_overhead;
+
+                                if (new_total_usage <= gpu_free_memory[idx]) {
+                                    additional_layers = test_layers;
+                                } else {
+                                    break;
+                                }
+                            }
+
+                            if (additional_layers > 0) {
+                                int layers_to_add = std::min(additional_layers, remaining_layers);
+                                layers_per_gpu[idx] += layers_to_add;
+                                remaining_layers -= layers_to_add;
+
+                                LLAMA_LOG_INFO("%s: Device %zu [%s]: redistributed %d additional layers (total now %d)\n",
+                                               __func__, idx, ggml_backend_dev_name(devices[idx]),
+                                               layers_to_add, layers_per_gpu[idx]);
+                            }
+                        }
+                    }
+                }
+            }
+
+            // Warn if we couldn't place all layers
+            if (remaining_layers > 0) {
+                LLAMA_LOG_ERROR("%s: WARNING: Could not assign %d layers to GPUs. Consider:\n",
+                                __func__, remaining_layers);
+                LLAMA_LOG_ERROR("%s: - Reducing context size (current: %u)\n",
+                                __func__, n_ctx_for_kv);
+                LLAMA_LOG_ERROR("%s: - Using fewer layers (-ngl)\n", __func__);
+                LLAMA_LOG_ERROR("%s: - Adding more GPU memory\n", __func__);
+
+                // Put remaining layers on CPU (will be updated below)
+            }
+
+            // Convert layer counts to split ratios
+            splits.clear();
+            splits.resize(n_devices());
+            float cumsum = 0.0f;
+
+            // Calculate total layers actually assigned
+            int total_assigned_layers = 0;
+            for (size_t i = 0; i < n_devices(); ++i) {
+                total_assigned_layers += layers_per_gpu[i];
+            }
+
+            // Update act_gpu_layers to match what we actually assigned
+            act_gpu_layers = total_assigned_layers;
+
+            for (size_t i = 0; i < n_devices(); ++i) {
+                cumsum += (float)layers_per_gpu[i] / act_gpu_layers;
+                splits[i] = cumsum;
+            }
+
+            LLAMA_LOG_INFO("%s: Final split ratios: ", __func__);
+            for (size_t i = 0; i < n_devices(); ++i) {
+                LLAMA_LOG_CONT("%.3f ", splits[i]);
+            }
+            LLAMA_LOG_CONT("\n");
+        }
+    }
+
     ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
     int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
 
@@ -15167,6 +15474,7 @@ llama_model_params llama_model_default_params() {
         /*.progress_callback =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.kv_overrides =*/ nullptr,
+        /*.requested_n_ctx =*/ 0,
         /*.vocab_only =*/ false,
         /*.use_mmap =*/ true,
         /*.use_mlock =*/ false,
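
To get a feel for the magnitudes behind the new log messages, the short standalone program below (not part of the commit) evaluates the same two sizing formulas for a hypothetical 32-layer model with 32 attention heads, 8 KV heads of dimension 128, and an f16 KV cache at a 32768-token context; the model shape is an assumption chosen purely for illustration.

// Illustrative arithmetic for the sizing formulas used by the planner:
//   KV cache total      = 2 (K+V) * n_ctx * n_layer * n_embd_kv * bytes_per_element
//   compute buffer est. = n_ctx * n_ubatch * n_head * sizeof(float)
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical model shape (roughly a 7-8B class model with grouped-query attention).
    const uint64_t n_ctx       = 32768;
    const uint64_t n_layer     = 32;
    const uint64_t n_head      = 32;
    const uint64_t n_head_kv   = 8;
    const uint64_t n_embd_head = 128;
    const uint64_t n_embd_kv   = n_head_kv * n_embd_head;  // 1024
    const uint64_t n_ubatch    = 512;
    const uint64_t kv_bytes    = 2;                         // f16 K/V cells

    const uint64_t kv_cache_total = 2 * n_ctx * n_layer * n_embd_kv * kv_bytes;
    const uint64_t compute_buffer = n_ctx * n_ubatch * n_head * sizeof(float);

    std::printf("KV cache total : %.2f GiB\n", kv_cache_total / 1024.0 / 1024.0 / 1024.0);
    std::printf("compute buffer : %.2f GiB per device\n", compute_buffer / 1024.0 / 1024.0 / 1024.0);
    return 0;
}

With those assumptions the planner budgets roughly 4 GiB of KV cache spread across the GPUs, plus about 2 GiB of compute buffer and 512 MiB of overhead on every device before any weights are placed, which is why small cards can be excluded outright at long contexts.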
