Skip to content

Commit 9b8eace

Browse files
author
Killian Perlin
committed
Add support for class hierarchy
Class hierarchy is taken into account when checking the type of a dynamic node instead of verifying for strict equality.
1 parent a394748 commit 9b8eace

File tree

6 files changed

+300
-4
lines changed

6 files changed

+300
-4
lines changed

lkql_jit/language/src/main/java/com/adacore/lkql_jit/LKQLContext.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import com.adacore.lkql_jit.checker.UnitChecker;
1414
import com.adacore.lkql_jit.checker.utils.CheckerUtils;
1515
import com.adacore.lkql_jit.exception.LKQLRuntimeException;
16+
import com.adacore.lkql_jit.langkit_translator.passes.Hierarchy;
1617
import com.adacore.lkql_jit.options.LKQLOptions;
1718
import com.adacore.lkql_jit.options.RuleInstance;
1819
import com.adacore.lkql_jit.runtime.GlobalScope;
@@ -133,6 +134,11 @@ public final class LKQLContext {
133134
@CompilerDirectives.CompilationFinal
134135
private CheckerUtils.DiagnosticEmitter emitter;
135136

137+
// ----- Nanopass typing context -----
138+
139+
/** Typing context used by pattern-matching during a rewriting pass */
140+
private Hierarchy typingContext = null;
141+
136142
// ----- Constructors -----
137143

138144
/**
@@ -218,6 +224,10 @@ public LangkitSupport.RewritingApplyResult applyOrCloseRewritingContext() {
218224
return res;
219225
}
220226

227+
public Hierarchy getTypingContext() {
228+
return typingContext;
229+
}
230+
221231
// ----- Setters -----
222232

223233
public void patchContext(TruffleLanguage.Env newEnv) {
@@ -227,6 +237,10 @@ public void patchContext(TruffleLanguage.Env newEnv) {
227237
this.initSources();
228238
}
229239

240+
public void setTypingContext(Hierarchy typingContext) {
241+
this.typingContext = typingContext;
242+
}
243+
230244
// ----- Options getting methods -----
231245

232246
/** Parse the LKQL engine options passed as a JSON string, store it in a cache and return it. */
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
//
2+
// Copyright (C) 2005-2025, AdaCore
3+
// SPDX-License-Identifier: GPL-3.0-or-later
4+
//
5+
6+
package com.adacore.lkql_jit.langkit_translator.passes;
7+
8+
import com.adacore.libadalang.Libadalang;
9+
import com.adacore.libadalang.Libadalang.AdaNode;
10+
import java.util.*;
11+
12+
/**
13+
14+
This class is a representation of a totally ordered set of elements S.
15+
It is used to store the class hierarchy of arbitrarily named classes.
16+
17+
Each class X can have one super-class Y, which we note X < Y.
18+
Because this table is used to query `instanceof` operations, we store
19+
the relation `≤` rather than `<`.
20+
21+
The relations are stored in a matrix of size |S|².
22+
Example:
23+
24+
The hierarchy
25+
26+
<pre>.
27+
A
28+
/ \
29+
B C
30+
/ \
31+
D E
32+
</pre>
33+
34+
with the index [A=0, B=1, C=2, D=3, E=4]
35+
36+
is stored in this table (an `x` represents a `≤` relation)
37+
38+
<pre>.
39+
A B C D E
40+
A x
41+
B x x
42+
C x x
43+
D x x x
44+
E x x x
45+
</pre>
46+
47+
The matrix cannot be stored in a triangle, it must be a full square !
48+
To demonstrate this, consider the index [B=0, C=1, D=2, E=3, A=4]
49+
putting A to the end, the table becomes:
50+
51+
<pre>.
52+
B C D E A
53+
B x x
54+
C x x
55+
D x x x
56+
E x x x
57+
A x
58+
</pre>
59+
60+
Because the table is stored as a flat array
61+
the X ≤ Y relation is found at `index(X) * size + index(Y)`
62+
of the array.
63+
64+
*/
65+
public class Hierarchy {
66+
67+
private int classCount;
68+
private boolean[] inheritanceMatrix;
69+
private HashMap<String, Integer> classNamesToIndex;
70+
71+
private void become(int size, boolean[] matrix, HashMap<String, Integer> index) {
72+
this.classCount = size;
73+
this.inheritanceMatrix = matrix;
74+
this.classNamesToIndex = index;
75+
}
76+
77+
public static Hierarchy initial() {
78+
// Collect all inital classes
79+
final Class<? extends AdaNode>[] initialClasses = Libadalang.NODE_DESCRIPTION_MAP.values()
80+
.stream()
81+
.map(e -> e.clazz)
82+
.toArray(Class[]::new);
83+
84+
final int size = initialClasses.length;
85+
86+
// Compute X ≤ Y matrix
87+
final boolean[] matrix = new boolean[size * size];
88+
89+
for (int i = 0; i < size; i++) {
90+
for (int j = 0; j < size; j++) {
91+
matrix[i * size + j] = initialClasses[j].isAssignableFrom(initialClasses[i]);
92+
}
93+
}
94+
95+
// Assign an integer in 0..<N to each class
96+
final HashMap<String, Integer> index = new HashMap<>(size);
97+
for (int i = 0; i < size; i++) {
98+
index.put(initialClasses[i].getSimpleName(), i);
99+
}
100+
101+
final var res = new Hierarchy();
102+
res.become(size, matrix, index);
103+
return res;
104+
}
105+
106+
/**
107+
* The number of classes present in the hierarchy
108+
*/
109+
public int size() {
110+
return classCount;
111+
}
112+
113+
/**
114+
* Main query method for the hierarchy.
115+
* This is equivalent to `classY.isAssignableFrom(classX)`.
116+
* In ordered set notation this corresponds to X ≤ Y.
117+
*
118+
* @see java.lang.Class.isAssignableFrom(Class<?> cls)
119+
*/
120+
@TruffleBoundary
121+
public boolean isInstance(String classX, String classY) {
122+
return inheritanceMatrix[classNamesToIndex.get(classX) * size() +
123+
classNamesToIndex.get(classY)];
124+
}
125+
126+
public void add(String classX) {
127+
addAll(Collections.singleton(classX));
128+
}
129+
130+
public void addAll(Collection<String> classes) {
131+
final int size = this.size() + classes.size();
132+
final boolean[] matrix = new boolean[size * size];
133+
134+
// Copy original matrix
135+
for (int i = 0; i < this.size(); i++) {
136+
for (int j = 0; j < this.size(); j++) {
137+
matrix[i * size + j] = this.inheritanceMatrix[i * this.size() + j];
138+
}
139+
}
140+
141+
// Forall I . I ≤ I
142+
for (int i = 0; i < size; i++) {
143+
matrix[i * size + i] = true;
144+
}
145+
146+
final var index = new HashMap<>(this.classNamesToIndex);
147+
int i = this.size();
148+
for (var classI : classes) {
149+
index.put(classI, i++);
150+
}
151+
152+
become(size, matrix, index);
153+
}
154+
155+
/**
156+
* Add a new X ≤ Y relation to the hierarchy.
157+
*/
158+
@TruffleBoundary
159+
public void addInstanceOfRelation(String classX, String classY) {
160+
final int x = classNamesToIndex.get(classX);
161+
final int y = classNamesToIndex.get(classY);
162+
163+
inheritanceMatrix[x * size() + y] = true; // X ≤ Y
164+
165+
// forall I , I ≤ X => I ≤ Y
166+
for (int i = 0; i < size(); i++) {
167+
inheritanceMatrix[i * size() + y] |= inheritanceMatrix[i * size() + x];
168+
}
169+
// forall I , Y ≤ I => X ≤ I
170+
for (int i = 0; i < size(); i++) {
171+
inheritanceMatrix[x * size() + i] |= inheritanceMatrix[y * size() + i];
172+
}
173+
}
174+
175+
public void remove(String classX) {
176+
removeAll(Collections.singleton(classX));
177+
}
178+
179+
public void removeAll(Collection<String> classes) {
180+
// Compute all the subtypes of X to mask wich row/column to keep
181+
final boolean[] deleteMask = new boolean[size()];
182+
int deleteCount = 0;
183+
for (var classX : classes) {
184+
final int x = classNamesToIndex.get(classX);
185+
// forall I , I ≤ X => I in delete
186+
for (int i = 0; i < size(); i++) {
187+
if (inheritanceMatrix[i * size() + x] && !deleteMask[i]) {
188+
deleteMask[i] = true;
189+
deleteCount++;
190+
}
191+
}
192+
}
193+
194+
// Copy relevant rows/columns to new matrix
195+
final int size = this.size() - deleteCount;
196+
final boolean[] matrix = new boolean[size * size];
197+
198+
int iCursor = 0;
199+
for (int i = 0; i < this.size(); i++) {
200+
if (deleteMask[i]) continue;
201+
202+
int jCursor = 0;
203+
for (int j = 0; j < this.size(); j++) {
204+
if (deleteMask[j]) continue;
205+
206+
matrix[iCursor * size + jCursor] = this.inheritanceMatrix[i * this.size() + j];
207+
208+
jCursor++;
209+
}
210+
211+
iCursor++;
212+
}
213+
214+
// Adjust indexes by shifting them down
215+
final var reverseIndex = new String[this.size()];
216+
for (var entry : this.classNamesToIndex.entrySet()) {
217+
reverseIndex[entry.getValue()] = entry.getKey();
218+
}
219+
220+
final var index = new HashMap<String, Integer>(size);
221+
222+
int cursor = 0;
223+
for (int i = 0; i < this.size(); i++) {
224+
if (deleteMask[i]) continue;
225+
index.put(reverseIndex[i], cursor);
226+
cursor++;
227+
}
228+
229+
become(size, matrix, index);
230+
}
231+
232+
// Primarily used in testing
233+
public boolean equals(Object otherObject) {
234+
if (otherObject == null) return false;
235+
if (otherObject instanceof Hierarchy other) {
236+
if (size != other.size) return false;
237+
if (!index.keySet().equals(other.index.keySet())) return false;
238+
239+
for (var x : index.keySet()) {
240+
final var classX = index.get(x);
241+
final var otherX = other.index.get(x);
242+
for (var y : index.keySet()) {
243+
final var classY = index.get(y);
244+
final var otherY = other.index.get(y);
245+
if (
246+
matrix[classX * size + classY] != other.matrix[otherX * size + otherY]
247+
) return false;
248+
}
249+
}
250+
251+
return true;
252+
}
253+
return false;
254+
}
255+
}

lkql_jit/language/src/main/java/com/adacore/lkql_jit/nodes/pass/PassExpr.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package com.adacore.lkql_jit.nodes.pass;
77

8+
import com.adacore.lkql_jit.LKQLLanguage;
89
import com.adacore.lkql_jit.exception.LKQLRuntimeException;
910
import com.adacore.lkql_jit.nodes.expressions.Expr;
1011
import com.adacore.lkql_jit.nodes.expressions.value_read.ReadArgument;
@@ -57,7 +58,10 @@ protected PassExpr(
5758
// this can be removed if pattern matching exhaustivity is pre-checked
5859
@Specialization
5960
public Object onDynamicAdaNode(VirtualFrame frame, DynamicAdaNode input) {
61+
final var typingContext = LKQLLanguage.getContext(this).getTypingContext();
62+
typingContext.addAll(add.classes.stream().map(cd -> cd.name).toList());
6063
final var updatedTree = getUpdatedTree(input, frame);
64+
typingContext.removeAll(del.classes);
6165
return updatedTree;
6266
}
6367

lkql_jit/language/src/main/java/com/adacore/lkql_jit/nodes/pass/RunPass.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package com.adacore.lkql_jit.nodes.pass;
77

88
import com.adacore.lkql_jit.LKQLLanguage;
9+
import com.adacore.lkql_jit.langkit_translator.passes.Hierarchy;
910
import com.adacore.lkql_jit.nodes.LKQLNode;
1011
import com.adacore.lkql_jit.runtime.values.AdaNodeProxy;
1112
import com.adacore.lkql_jit.runtime.values.LKQLFunction;
@@ -57,6 +58,8 @@ public Object executeGeneric(VirtualFrame frame) {
5758
units[i] = AdaNodeProxy.convertAST(roots[i]);
5859
}
5960

61+
LKQLLanguage.getContext(this).setTypingContext(Hierarchy.initial());
62+
6063
do {
6164
final var pass = callChain.pop();
6265

lkql_jit/language/src/main/java/com/adacore/lkql_jit/nodes/patterns/node_patterns/DynamicNodeKindPattern.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
package com.adacore.lkql_jit.nodes.patterns.node_patterns;
77

8-
import com.adacore.lkql_jit.runtime.values.DynamicAdaNode;
8+
import com.adacore.lkql_jit.LKQLLanguage;
9+
import com.adacore.lkql_jit.LKQLTypeSystemGen;
10+
import com.adacore.lkql_jit.runtime.values.AdaNodeProxy;
911
import com.oracle.truffle.api.frame.VirtualFrame;
1012
import com.oracle.truffle.api.source.SourceSection;
1113

@@ -41,9 +43,17 @@ public DynamicNodeKindPattern(SourceSection location, String kindName) {
4143
*/
4244
@Override
4345
public boolean executeValue(VirtualFrame frame, Object value) {
44-
if (value instanceof DynamicAdaNode node) {
45-
return node.kind.equals(kindName);
46-
} else return false;
46+
if (LKQLTypeSystemGen.isDynamicAdaNode(value)) {
47+
return LKQLLanguage.getContext(this)
48+
.getTypingContext()
49+
.isInstance(LKQLTypeSystemGen.asDynamicAdaNode(value).kind, kindName);
50+
} else if (LKQLTypeSystemGen.isNodeInterface(value)) {
51+
final var dynNode = AdaNodeProxy.convertAST(LKQLTypeSystemGen.asNodeInterface(value));
52+
return LKQLLanguage.getContext(this)
53+
.getTypingContext()
54+
.isInstance(dynNode.kind, kindName);
55+
}
56+
return false;
4757
}
4858

4959
// ----- Override methods -----

lkql_jit/language/src/main/java/com/adacore/lkql_jit/nodes/patterns/node_patterns/ExtendedNodePattern.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ public boolean executeValue(VirtualFrame frame, Object value) {
6565
if (!details[i].executeDetail(frame, node)) return false;
6666
}
6767

68+
// Return the success
69+
return true;
70+
} else if (LKQLTypeSystemGen.isDynamicAdaNode(value)) {
71+
var node = LKQLTypeSystemGen.asDynamicAdaNode(value);
72+
73+
// Verify all details
74+
for (NodePatternDetail detail : this.details) {
75+
if (!detail.executeDetail(frame, node)) return false;
76+
}
77+
6878
// Return the success
6979
return true;
7080
} else {

0 commit comments

Comments
 (0)