Skip to content

Commit 9a1656e

Browse files
committed
Refactor pareto optimise and convexify
1 parent 1a3e9ea commit 9a1656e

File tree

1 file changed

+41
-43
lines changed

1 file changed

+41
-43
lines changed

src/llama-quant.cpp

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,55 +1179,53 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
11791179
}
11801180

11811181
// Keep only the pareto‑optimal candidates and enforce convexity in (bytes, error) curve
1182-
{
1183-
auto & candidates = info.candidate;
1184-
if (!candidates.empty()) {
1185-
std::sort(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) {
1186-
if (a.bytes != b.bytes) { return a.bytes < b.bytes; }
1187-
1188-
return a.error < b.error;
1189-
});
1190-
1191-
std::vector<candidate_types> pareto;
1192-
pareto.reserve(candidates.size());
1193-
double best_err = infinity;
1194-
size_t last_bytes = std::numeric_limits<size_t>::max();
1195-
for (const auto & c : candidates) {
1196-
if (c.bytes != last_bytes) {
1197-
last_bytes = c.bytes;
1198-
if (c.error < best_err) {
1199-
best_err = c.error;
1200-
pareto.push_back(c);
1201-
}
1202-
}
1203-
}
1182+
auto pareto_convex = [](std::vector<candidate_types> & candidates) {
1183+
if (candidates.empty()) return;
12041184

1205-
candidates.swap(pareto);
1206-
1207-
if (candidates.size() >= 3) {
1208-
std::vector<candidate_types> hull;
1209-
hull.reserve(candidates.size());
1210-
auto slope = [](const candidate_types & a, const candidate_types & b) {
1211-
const double dx = b.bytes - a.bytes;
1212-
1213-
return dx <= 0.0 ? infinity : (b.error - a.error) / dx;
1214-
};
1215-
1216-
for (const auto & p : candidates) {
1217-
while (hull.size() >= 2) {
1218-
double s1 = slope(hull[hull.size() - 2], hull[hull.size() - 1]);
1219-
double s2 = slope(hull[hull.size() - 1], p);
1220-
if (s2 + epsilon < s1) { hull.pop_back(); }
1221-
else { break; }
1222-
}
1185+
std::sort(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) {
1186+
if (a.bytes != b.bytes) { return a.bytes < b.bytes; }
1187+
return a.error < b.error;
1188+
});
12231189

1224-
hull.push_back(p);
1190+
// Pareto by bytes -> error
1191+
std::vector<candidate_types> pareto;
1192+
pareto.reserve(candidates.size());
1193+
double best_err = std::numeric_limits<double>::infinity();
1194+
size_t last_b = std::numeric_limits<size_t>::max();
1195+
for (const auto & c : candidates) {
1196+
if (c.bytes != last_b) {
1197+
last_b = c.bytes;
1198+
if (c.error < best_err) {
1199+
best_err = c.error;
1200+
pareto.push_back(c);
12251201
}
1202+
}
1203+
}
12261204

1227-
candidates.swap(hull);
1205+
candidates.swap(pareto);
1206+
if (candidates.size() < 3) { return; } // need at least 3 points to do convex hull
1207+
1208+
// Convex hull (lower envelope)
1209+
auto slope = [](const candidate_types & a, const candidate_types & b) {
1210+
const double dx = b.bytes - a.bytes;
1211+
return dx <= 0.0 ? infinity : (b.error - a.error) / dx;
1212+
};
1213+
1214+
std::vector<candidate_types> hull; hull.reserve(candidates.size());
1215+
for (const auto & p : candidates) {
1216+
while (hull.size() >= 2) {
1217+
const double s1 = slope(hull[hull.size() - 2], hull[hull.size() - 1]);
1218+
const double s2 = slope(hull[hull.size() - 1], p);
1219+
if (s2 + epsilon < s1) hull.pop_back();
1220+
else { break; }
12281221
}
1222+
1223+
hull.push_back(p);
12291224
}
1230-
}
1225+
candidates.swap(hull);
1226+
};
1227+
1228+
pareto_convex(info.candidate);
12311229

12321230
// Initialize choice at the smallest bpw candidate
12331231
info.choice = 0;

0 commit comments

Comments
 (0)