Skip to content

Commit 257ad51

Browse files
authored
Bring back automaton minimization (#119309)
The security codebase relies heavily on automata and caching these. The Lucene 10 upgrade removed automaton minimization which can result in a memory usage increase of >5x, esp. for roles with many application privileges. This PR brings back Automaton minimization to avoid the explosion in roles cache size. Relates: ES-10451
1 parent 73f0a5e commit 257ad51

File tree

11 files changed

+429
-32
lines changed

11 files changed

+429
-32
lines changed

server/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,4 +480,5 @@
480480
exports org.elasticsearch.inference.configuration;
481481
exports org.elasticsearch.monitor.metrics;
482482
exports org.elasticsearch.plugins.internal.rewriter to org.elasticsearch.inference;
483+
exports org.elasticsearch.lucene.util.automaton;
483484
}
Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.lucene.util.automaton;
11+
12+
import org.apache.lucene.internal.hppc.IntArrayList;
13+
import org.apache.lucene.internal.hppc.IntCursor;
14+
import org.apache.lucene.internal.hppc.IntHashSet;
15+
import org.apache.lucene.util.automaton.Automaton;
16+
import org.apache.lucene.util.automaton.Operations;
17+
import org.apache.lucene.util.automaton.Transition;
18+
19+
import java.util.BitSet;
20+
import java.util.LinkedList;
21+
22+
/**
23+
* Operations for minimizing automata.
24+
* <p>
25+
* Lucene 10 removed minimization, but Elasticsearch still requires it.
26+
* Minimization is critical in the security codebase to reduce the heap
27+
* usage of automata used for permission checks.
28+
* <p>
29+
* Copied of Lucene's AutomatonTestUtil
30+
*/
31+
public final class MinimizationOperations {
32+
33+
private MinimizationOperations() {}
34+
35+
/**
36+
* Minimizes (and determinizes if not already deterministic) the given automaton using Hopcroft's
37+
* algorithm.
38+
*
39+
* @param determinizeWorkLimit maximum effort to spend determinizing the automaton. Set higher to
40+
* allow more complex queries and lower to prevent memory exhaustion. Use {@link
41+
* Operations#DEFAULT_DETERMINIZE_WORK_LIMIT} as a decent default if you don't otherwise know
42+
* what to specify.
43+
*/
44+
public static Automaton minimize(Automaton a, int determinizeWorkLimit) {
45+
46+
if (a.getNumStates() == 0 || (a.isAccept(0) == false && a.getNumTransitions(0) == 0)) {
47+
// Fastmatch for common case
48+
return new Automaton();
49+
}
50+
a = Operations.determinize(a, determinizeWorkLimit);
51+
// a.writeDot("adet");
52+
if (a.getNumTransitions(0) == 1) {
53+
Transition t = new Transition();
54+
a.getTransition(0, 0, t);
55+
if (t.dest == 0 && t.min == Character.MIN_CODE_POINT && t.max == Character.MAX_CODE_POINT) {
56+
// Accepts all strings
57+
return a;
58+
}
59+
}
60+
a = totalize(a);
61+
// a.writeDot("atot");
62+
63+
// initialize data structures
64+
final int[] sigma = a.getStartPoints();
65+
final int sigmaLen = sigma.length, statesLen = a.getNumStates();
66+
67+
final IntArrayList[][] reverse = new IntArrayList[statesLen][sigmaLen];
68+
final IntHashSet[] partition = new IntHashSet[statesLen];
69+
final IntArrayList[] splitblock = new IntArrayList[statesLen];
70+
final int[] block = new int[statesLen];
71+
final StateList[][] active = new StateList[statesLen][sigmaLen];
72+
final StateListNode[][] active2 = new StateListNode[statesLen][sigmaLen];
73+
final LinkedList<IntPair> pending = new LinkedList<>();
74+
final BitSet pending2 = new BitSet(sigmaLen * statesLen);
75+
final BitSet split = new BitSet(statesLen), refine = new BitSet(statesLen), refine2 = new BitSet(statesLen);
76+
for (int q = 0; q < statesLen; q++) {
77+
splitblock[q] = new IntArrayList();
78+
partition[q] = new IntHashSet();
79+
for (int x = 0; x < sigmaLen; x++) {
80+
active[q][x] = StateList.EMPTY;
81+
}
82+
}
83+
// find initial partition and reverse edges
84+
for (int q = 0; q < statesLen; q++) {
85+
// TODO moved the following into the loop because we cannot reset transition.transitionUpto (pkg private)
86+
Transition transition = new Transition();
87+
final int j = a.isAccept(q) ? 0 : 1;
88+
partition[j].add(q);
89+
block[q] = j;
90+
transition.source = q;
91+
// TODO we'd need to be able to access pkg private transition.transitionUpto if we want to optimize the following
92+
// transition.transitionUpto = -1;
93+
for (int x = 0; x < sigmaLen; x++) {
94+
final IntArrayList[] r = reverse[a.next(transition, sigma[x])];
95+
if (r[x] == null) {
96+
r[x] = new IntArrayList();
97+
}
98+
r[x].add(q);
99+
}
100+
}
101+
// initialize active sets
102+
for (int j = 0; j <= 1; j++) {
103+
for (int x = 0; x < sigmaLen; x++) {
104+
for (IntCursor qCursor : partition[j]) {
105+
int q = qCursor.value;
106+
if (reverse[q][x] != null) {
107+
StateList stateList = active[j][x];
108+
if (stateList == StateList.EMPTY) {
109+
stateList = new StateList();
110+
active[j][x] = stateList;
111+
}
112+
active2[q][x] = stateList.add(q);
113+
}
114+
}
115+
}
116+
}
117+
118+
// initialize pending
119+
for (int x = 0; x < sigmaLen; x++) {
120+
final int j = (active[0][x].size <= active[1][x].size) ? 0 : 1;
121+
pending.add(new IntPair(j, x));
122+
pending2.set(x * statesLen + j);
123+
}
124+
125+
// process pending until fixed point
126+
int k = 2;
127+
// System.out.println("start min");
128+
while (false == pending.isEmpty()) {
129+
// System.out.println(" cycle pending");
130+
final IntPair ip = pending.removeFirst();
131+
final int p = ip.n1;
132+
final int x = ip.n2;
133+
// System.out.println(" pop n1=" + ip.n1 + " n2=" + ip.n2);
134+
pending2.clear(x * statesLen + p);
135+
// find states that need to be split off their blocks
136+
for (StateListNode m = active[p][x].first; m != null; m = m.next) {
137+
final IntArrayList r = reverse[m.q][x];
138+
if (r != null) {
139+
for (IntCursor iCursor : r) {
140+
final int i = iCursor.value;
141+
if (false == split.get(i)) {
142+
split.set(i);
143+
final int j = block[i];
144+
splitblock[j].add(i);
145+
if (false == refine2.get(j)) {
146+
refine2.set(j);
147+
refine.set(j);
148+
}
149+
}
150+
}
151+
}
152+
}
153+
154+
// refine blocks
155+
for (int j = refine.nextSetBit(0); j >= 0; j = refine.nextSetBit(j + 1)) {
156+
final IntArrayList sb = splitblock[j];
157+
if (sb.size() < partition[j].size()) {
158+
final IntHashSet b1 = partition[j];
159+
final IntHashSet b2 = partition[k];
160+
for (IntCursor iCursor : sb) {
161+
final int s = iCursor.value;
162+
b1.remove(s);
163+
b2.add(s);
164+
block[s] = k;
165+
for (int c = 0; c < sigmaLen; c++) {
166+
final StateListNode sn = active2[s][c];
167+
if (sn != null && sn.sl == active[j][c]) {
168+
sn.remove();
169+
StateList stateList = active[k][c];
170+
if (stateList == StateList.EMPTY) {
171+
stateList = new StateList();
172+
active[k][c] = stateList;
173+
}
174+
active2[s][c] = stateList.add(s);
175+
}
176+
}
177+
}
178+
// update pending
179+
for (int c = 0; c < sigmaLen; c++) {
180+
final int aj = active[j][c].size, ak = active[k][c].size, ofs = c * statesLen;
181+
if ((false == pending2.get(ofs + j)) && 0 < aj && aj <= ak) {
182+
pending2.set(ofs + j);
183+
pending.add(new IntPair(j, c));
184+
} else {
185+
pending2.set(ofs + k);
186+
pending.add(new IntPair(k, c));
187+
}
188+
}
189+
k++;
190+
}
191+
refine2.clear(j);
192+
for (IntCursor iCursor : sb) {
193+
final int s = iCursor.value;
194+
split.clear(s);
195+
}
196+
sb.clear();
197+
}
198+
refine.clear();
199+
}
200+
201+
Automaton result = new Automaton();
202+
203+
Transition t = new Transition();
204+
205+
// System.out.println(" k=" + k);
206+
207+
// make a new state for each equivalence class, set initial state
208+
int[] stateMap = new int[statesLen];
209+
int[] stateRep = new int[k];
210+
211+
result.createState();
212+
213+
// System.out.println("min: k=" + k);
214+
for (int n = 0; n < k; n++) {
215+
// System.out.println(" n=" + n);
216+
217+
boolean isInitial = partition[n].contains(0);
218+
219+
int newState;
220+
if (isInitial) {
221+
// System.out.println(" isInitial!");
222+
newState = 0;
223+
} else {
224+
newState = result.createState();
225+
}
226+
227+
// System.out.println(" newState=" + newState);
228+
229+
for (IntCursor qCursor : partition[n]) {
230+
int q = qCursor.value;
231+
stateMap[q] = newState;
232+
// System.out.println(" q=" + q + " isAccept?=" + a.isAccept(q));
233+
result.setAccept(newState, a.isAccept(q));
234+
stateRep[newState] = q; // select representative
235+
}
236+
}
237+
238+
// build transitions and set acceptance
239+
for (int n = 0; n < k; n++) {
240+
int numTransitions = a.initTransition(stateRep[n], t);
241+
for (int i = 0; i < numTransitions; i++) {
242+
a.getNextTransition(t);
243+
// System.out.println(" add trans");
244+
result.addTransition(n, stateMap[t.dest], t.min, t.max);
245+
}
246+
}
247+
result.finishState();
248+
// System.out.println(result.getNumStates() + " states");
249+
250+
return Operations.removeDeadStates(result);
251+
}
252+
253+
record IntPair(int n1, int n2) {}
254+
255+
static final class StateList {
256+
257+
// Empty list that should never be mutated, used as a memory saving optimization instead of null
258+
// so we don't need to branch the read path in #minimize
259+
static final StateList EMPTY = new StateList();
260+
261+
int size;
262+
263+
StateListNode first, last;
264+
265+
StateListNode add(int q) {
266+
assert this != EMPTY;
267+
return new StateListNode(q, this);
268+
}
269+
}
270+
271+
static final class StateListNode {
272+
273+
final int q;
274+
275+
StateListNode next, prev;
276+
277+
final StateList sl;
278+
279+
StateListNode(int q, StateList sl) {
280+
this.q = q;
281+
this.sl = sl;
282+
if (sl.size++ == 0) sl.first = sl.last = this;
283+
else {
284+
sl.last.next = this;
285+
prev = sl.last;
286+
sl.last = this;
287+
}
288+
}
289+
290+
void remove() {
291+
sl.size--;
292+
if (sl.first == this) sl.first = next;
293+
else prev.next = next;
294+
if (sl.last == this) sl.last = prev;
295+
else next.prev = prev;
296+
}
297+
}
298+
299+
static Automaton totalize(Automaton a) {
300+
Automaton result = new Automaton();
301+
int numStates = a.getNumStates();
302+
for (int i = 0; i < numStates; i++) {
303+
result.createState();
304+
result.setAccept(i, a.isAccept(i));
305+
}
306+
307+
int deadState = result.createState();
308+
result.addTransition(deadState, deadState, Character.MIN_CODE_POINT, Character.MAX_CODE_POINT);
309+
310+
Transition t = new Transition();
311+
for (int i = 0; i < numStates; i++) {
312+
int maxi = Character.MIN_CODE_POINT;
313+
int count = a.initTransition(i, t);
314+
for (int j = 0; j < count; j++) {
315+
a.getNextTransition(t);
316+
result.addTransition(i, t.dest, t.min, t.max);
317+
if (t.min > maxi) {
318+
result.addTransition(i, deadState, maxi, t.min - 1);
319+
}
320+
if (t.max + 1 > maxi) {
321+
maxi = t.max + 1;
322+
}
323+
}
324+
325+
if (maxi <= Character.MAX_CODE_POINT) {
326+
result.addTransition(i, deadState, maxi, Character.MAX_CODE_POINT);
327+
}
328+
}
329+
330+
result.finishState();
331+
return result;
332+
}
333+
}

0 commit comments

Comments
 (0)