Skip to content

Commit 56517df

Browse files
authored
lcb: Implement lowering for conditional move and selects (#488)
* lcb: Added custom lowering for setcc * lcb: Lowering of SelectCC * lcb: Implemented lowering of Select for AArch64 * lcb: Added matchings of CSEL * chore: Make CheckStyle happy * lcb: Fixed SETCC
2 parents 61cd679 + 9298670 commit 56517df

File tree

7 files changed

+210
-47
lines changed

7 files changed

+210
-47
lines changed

vadl/main/resources/templates/lcb/llvm/lib/Target/ISelLowering.cpp

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,7 @@ void [(${namespace})]TargetLowering::anchor() {}
3939
setOperationAction(ISD::VAARG, MVT::Other, Custom);
4040
setOperationAction(ISD::VACOPY, MVT::Other, Expand);
4141
setOperationAction(ISD::VAEND, MVT::Other, Expand);
42-
[#th:block th:if="${!hasCMove32 && stackPointerBitWidth == 32}"]
43-
setOperationAction(ISD::SELECT, MVT::i32, Custom);
44-
[/th:block]
45-
[#th:block th:if="${!hasCMove64 && stackPointerBitWidth == 64}"]
46-
setOperationAction(ISD::SELECT, MVT::i64, Custom);
47-
[/th:block]
48-
setOperationAction(ISD::SELECT_CC, MVT::[(${stackPointerType})], Expand);
42+
setOperationAction(ISD::SELECT, MVT::[(${stackPointerType})], Custom);
4943
setOperationAction(ISD::SMUL_LOHI, MVT::i32, Expand);
5044
setOperationAction(ISD::UMUL_LOHI, MVT::i32, Expand);
5145
for (auto VT : {MVT::i1, MVT::i8, MVT::i16, MVT::i32}) {
@@ -67,6 +61,11 @@ void [(${namespace})]TargetLowering::anchor() {}
6761
[#th:block th:if="${mergedCmpAndBranch}"]
6862
setOperationAction(ISD::BRCOND, MVT::Other, Expand);
6963
setOperationAction(ISD::BR_CC, MVT::[(${stackPointerType})], Custom);
64+
setOperationAction(ISD::SELECT_CC, MVT::[(${stackPointerType})], Custom);
65+
setOperationAction(ISD::SETCC, MVT::[(${stackPointerType})], Expand);
66+
[/th:block]
67+
[#th:block th:if="${!mergedCmpAndBranch}"]
68+
setOperationAction(ISD::SELECT_CC, MVT::[(${stackPointerType})], Expand);
7069
[/th:block]
7170

7271
setBooleanContents(ZeroOrOneBooleanContent);
@@ -92,6 +91,20 @@ const char *[(${namespace})]TargetLowering::getTargetNodeName(unsigned Opcode) c
9291
}
9392
}
9493

94+
static SDValue lowerSetcc(SDValue Op, SelectionDAG &DAG)
95+
{
96+
Op->dump();
97+
SDValue LHS = Op.getOperand(0);
98+
SDValue RHS = Op.getOperand(1);
99+
SDLoc dl(Op);
100+
101+
EVT VT = Op.getValueType();
102+
SDValue TVal = DAG.getConstant(1, dl, VT);
103+
SDValue FVal = DAG.getConstant(0, dl, VT);
104+
105+
return DAG.getNode(ISD::SELECT_CC, dl, Op->getValueType(0), LHS, RHS, TVal, FVal, Op.getOperand(2));
106+
}
107+
95108
[#th:block th:if="${mergedCmpAndBranch}"]
96109
static SDValue lowerBR_CC(SDValue Op, SelectionDAG &DAG) {
97110
SDValue Chain = Op.getOperand(0);
@@ -101,7 +114,7 @@ static SDValue lowerBR_CC(SDValue Op, SelectionDAG &DAG) {
101114
SDValue Dest = Op.getOperand(4);
102115
SDLoc dl(Op);
103116

104-
auto Sub = DAG.getMachineNode([(${namespace})]::[(${SUBS})], dl, { MVT::i64, MVT::Other }, {LHS, RHS, Chain});
117+
auto Sub = DAG.getMachineNode([(${namespace})]::[(${SUBS})], dl, { MVT::[(${stackPointerType})], MVT::Other }, {LHS, RHS, Chain});
105118
SDValue ConditionFlag = SDValue(Sub, 1);
106119

107120
switch(CC) {
@@ -128,6 +141,46 @@ static SDValue lowerBR_CC(SDValue Op, SelectionDAG &DAG) {
128141
}
129142
}
130143

144+
static SDValue lowerSelectcc(SDValue Op, SelectionDAG &DAG)
145+
{
146+
Op->dump();
147+
SDValue LHS = Op.getOperand(0);
148+
SDValue RHS = Op.getOperand(1);
149+
auto TVal = Op.getOperand(2);
150+
auto FVal = Op.getOperand(3);
151+
152+
ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get();
153+
SDLoc dl(Op);
154+
155+
auto Sub = DAG.getMachineNode([(${namespace})]::[(${SUBS})], dl, { MVT::[(${stackPointerType})], MVT::Other }, {LHS, RHS});
156+
SDValue ConditionFlag = SDValue(Sub, 1);
157+
158+
switch(CC) {
159+
case ISD::CondCode::SETEQ:
160+
return SDValue(DAG.getMachineNode([(${namespace})]::[(${CSEL_EQ})], dl, MVT::[(${stackPointerType})], TVal, FVal, ConditionFlag ), 0);
161+
break;
162+
case ISD::CondCode::SETNE:
163+
return SDValue(DAG.getMachineNode([(${namespace})]::[(${CSEL_NEQ})], dl, MVT::[(${stackPointerType})], TVal, FVal, ConditionFlag ), 0);
164+
break;
165+
default:
166+
llvm_unreachable("unimplemented operand");
167+
}
168+
}
169+
170+
static SDValue lowerSelect2(SDValue Op, SelectionDAG &DAG)
171+
{
172+
Op->dump();
173+
SDValue Cond = Op.getOperand(0);
174+
SDValue LHS = Op.getOperand(1);
175+
SDValue RHS = Op.getOperand(2);
176+
SDLoc dl(Op);
177+
178+
auto Sub = DAG.getMachineNode([(${namespace})]::[(${SUBS})], dl, { MVT::[(${stackPointerType})], MVT::Other }, {Cond, Cond});
179+
SDValue ConditionFlag = SDValue(Sub, 1);
180+
181+
return SDValue(DAG.getMachineNode([(${namespace})]::[(${CSEL_EQ})], dl, MVT::[(${stackPointerType})], LHS, RHS, ConditionFlag), 0);
182+
}
183+
131184
[/th:block]
132185

133186
SDValue [(${namespace})]TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const
@@ -148,14 +201,21 @@ SDValue [(${namespace})]TargetLowering::LowerOperation(SDValue Op, SelectionDAG
148201
return lowerVASTART(Op, DAG);
149202
case ISD::VAARG:
150203
return lowerVAARG(Op, DAG);
151-
[#th:block th:if="${!hasConditionalMove}"]
152-
case ISD::SELECT:
153-
return lowerSelect(Op, DAG);
154-
[/th:block]
204+
case ISD::SETCC:
205+
return lowerSetcc(Op, DAG);
155206
[#th:block th:if="${mergedCmpAndBranch}"]
207+
case ISD::SELECT_CC:
208+
return lowerSelectcc(Op, DAG);
156209
case ISD::BR_CC:
157210
return lowerBR_CC(Op, DAG);
211+
case ISD::SELECT:
212+
return lowerSelect2(Op, DAG);
213+
[/th:block]
214+
[#th:block th:if="${!mergedCmpAndBranch}"]
215+
case ISD::SELECT:
216+
return lowerSelect(Op, DAG);
158217
[/th:block]
218+
159219
default : llvm_unreachable("unimplemented operand");
160220
}
161221
}

vadl/main/vadl/gcb/passes/MachineInstructionLabel.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ public enum MachineInstructionLabel {
100100
/*
101101
CONDITIONAL MOVE
102102
*/
103-
CMOVE_32,
104-
CMOVE_64;
105-
106-
103+
CSEL_EQ_I32,
104+
CSEL_EQ_I64,
105+
CSEL_NEQ_I32,
106+
CSEL_NEQ_I64
107107
}

vadl/main/vadl/lcb/passes/isaMatching/IsaMachineInstructionMatchingPass.java

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
import java.util.Objects;
6969
import java.util.Optional;
7070
import java.util.Set;
71+
import java.util.function.Predicate;
7172
import java.util.stream.Stream;
7273
import javax.annotation.Nullable;
7374
import vadl.configuration.GcbConfiguration;
@@ -82,7 +83,9 @@
8283
import vadl.types.BitsType;
8384
import vadl.types.BuiltInTable;
8485
import vadl.types.DataType;
86+
import vadl.types.SIntType;
8587
import vadl.types.Type;
88+
import vadl.viam.Constant;
8689
import vadl.viam.Counter;
8790
import vadl.viam.Instruction;
8891
import vadl.viam.InstructionSetArchitecture;
@@ -100,6 +103,7 @@
100103
import vadl.viam.graph.dependency.FieldAccessRefNode;
101104
import vadl.viam.graph.dependency.ReadMemNode;
102105
import vadl.viam.graph.dependency.ReadRegTensorNode;
106+
import vadl.viam.graph.dependency.SelectNode;
103107
import vadl.viam.graph.dependency.SignExtendNode;
104108
import vadl.viam.graph.dependency.SliceNode;
105109
import vadl.viam.graph.dependency.TruncateNode;
@@ -220,6 +224,22 @@ public Result execute(PassResults passResults, Specification viam) throws IOExce
220224
instruction.attachExtension(
221225
new MachineInstructionCtx(MachineInstructionLabel.SUB_RR_WITH_STATUS_REGISTER_32,
222226
Optional.empty()));
227+
} else if (findCSEL_EQ(originalGraph, Type.signedInt(32))) {
228+
instruction.attachExtension(
229+
new MachineInstructionCtx(MachineInstructionLabel.CSEL_EQ_I32,
230+
Optional.empty()));
231+
} else if (findCSEL_EQ(originalGraph, Type.signedInt(64))) {
232+
instruction.attachExtension(
233+
new MachineInstructionCtx(MachineInstructionLabel.CSEL_EQ_I64,
234+
Optional.empty()));
235+
} else if (findCSEL_NEQ(originalGraph, Type.signedInt(32))) {
236+
instruction.attachExtension(
237+
new MachineInstructionCtx(MachineInstructionLabel.CSEL_EQ_I32,
238+
Optional.empty()));
239+
} else if (findCSEL_NEQ(originalGraph, Type.signedInt(64))) {
240+
instruction.attachExtension(
241+
new MachineInstructionCtx(MachineInstructionLabel.CSEL_NEQ_I64,
242+
Optional.empty()));
223243
} else if (findRegisterRegisterOrRegisterImmediateOrImmediateRegister(behavior, SUB)) {
224244
instruction.attachExtension(new MachineInstructionCtx(MachineInstructionLabel.SUB, ty));
225245
} else if (findRegisterRegisterOrRegisterImmediateOrImmediateRegister(behavior,
@@ -340,6 +360,94 @@ public Result execute(PassResults passResults, Specification viam) throws IOExce
340360
return new Result(labels, flipIsaMatching(labels));
341361
}
342362

363+
private boolean findCSEL_32(Graph originalGraph, Constant constant) {
364+
var selectNode = originalGraph.getNodes(SelectNode.class).findFirst();
365+
var writes = originalGraph.getNodes(WritesRegisterTensor.class).toList();
366+
367+
if (writes.size() != 1) {
368+
return false;
369+
}
370+
371+
if (writes.stream().anyMatch(x -> !x.hasRegisterFile())) {
372+
return false;
373+
}
374+
375+
if (selectNode.isPresent()) {
376+
Predicate<Node> checkNode = (node) -> node instanceof ReadsRegisterTensor registerTensor
377+
&& registerTensor.hasRegisterFile();
378+
379+
if (checkNode.test(selectNode.get().trueCase())
380+
&& checkNode.test(selectNode.get().falseCase())
381+
&& selectNode.get().condition() instanceof BuiltInCall bc
382+
&& bc.builtIn() == EQU
383+
&& bc.arguments().get(1) instanceof ConstantNode constantNode
384+
&& constantNode.constant().equals(constant)) {
385+
if (originalGraph.getNodes(ReadsRegisterTensor.class).anyMatch(x -> x.registerTensor()
386+
.hasAnnotation(StatusRegisterAnnotation.ZeroStatusRegisterAnnotation.class))) {
387+
388+
return originalGraph.getNodes(TruncateNode.class)
389+
.anyMatch(x -> x.type().bitWidth() == Type.signedInt(32).bitWidth());
390+
}
391+
}
392+
}
393+
394+
return false;
395+
}
396+
397+
private boolean findCSEL_64(Graph originalGraph, Constant constant) {
398+
var selectNode = originalGraph.getNodes(SelectNode.class).findFirst();
399+
var writes = originalGraph.getNodes(WritesRegisterTensor.class).toList();
400+
401+
if (writes.size() != 1) {
402+
return false;
403+
}
404+
405+
if (writes.stream().anyMatch(x -> !x.hasRegisterFile())) {
406+
return false;
407+
}
408+
409+
if (selectNode.isPresent()) {
410+
Predicate<Node> checkNode = (node) -> node instanceof ReadsRegisterTensor registerTensor
411+
&& registerTensor.hasRegisterFile();
412+
413+
if (checkNode.test(selectNode.get().trueCase())
414+
&& checkNode.test(selectNode.get().falseCase())
415+
&& selectNode.get().condition() instanceof BuiltInCall bc
416+
&& bc.builtIn() == EQU
417+
&& bc.arguments().get(1) instanceof ConstantNode constantNode
418+
&& constantNode.constant().equals(constant)) {
419+
if (originalGraph.getNodes(ReadsRegisterTensor.class).anyMatch(x -> x.registerTensor()
420+
.hasAnnotation(StatusRegisterAnnotation.ZeroStatusRegisterAnnotation.class))) {
421+
422+
return originalGraph.getNodes(TruncateNode.class).toList().isEmpty();
423+
}
424+
}
425+
}
426+
427+
return false;
428+
}
429+
430+
431+
private boolean findCSEL_EQ(Graph originalGraph, SIntType ty) {
432+
if (ty.bitWidth() == 32) {
433+
return findCSEL_32(originalGraph, Constant.Value.one(DataType.bits(1)));
434+
} else if (ty.bitWidth() == 64) {
435+
return findCSEL_64(originalGraph, Constant.Value.one(DataType.bits(1)));
436+
}
437+
438+
return false;
439+
}
440+
441+
private boolean findCSEL_NEQ(Graph originalGraph, SIntType ty) {
442+
if (ty.bitWidth() == 32) {
443+
return findCSEL_32(originalGraph, Constant.Value.zero(DataType.bits(1)));
444+
} else if (ty.bitWidth() == 64) {
445+
return findCSEL_64(originalGraph, Constant.Value.zero(DataType.bits(1)));
446+
}
447+
448+
return false;
449+
}
450+
343451
private Optional<BitsType> getType(UninlinedGraph behavior) {
344452
var candidates = Stream.concat(
345453
behavior.getNodes(WriteRegTensorNode.class).filter(x -> x.regTensor().isRegisterFile())

vadl/main/vadl/lcb/template/lib/Target/EmitISelLoweringCppFilePass.java

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,9 @@ protected Map<String, Object> createVariables(final PassResults passResults,
9797
var framePointer = renderRegister(abi.framePointer().registerFile(), abi.framePointer().addr());
9898
var stackPointer = renderRegister(abi.stackPointer().registerFile(), abi.stackPointer().addr());
9999
var absoluteAddressLoadInstruction = abi.absoluteAddressLoad();
100-
var labelledMachineInstructions = ensureNonNull(
101-
(IsaMachineInstructionMatchingPass.Result) passResults.lastResultOf(
102-
IsaMachineInstructionMatchingPass.class),
103-
() -> Diagnostic.error("Cannot find semantics of the instructions",
104-
specification.location()))
105-
.labels();
106100
var coverageSummary =
107101
(ISelLoweringOperationActionPass.CoverageSummary) passResults.lastResultOf(
108102
ISelLoweringOperationActionPass.class);
109-
var hasCMove32 = labelledMachineInstructions.containsKey(MachineInstructionLabel.CMOVE_32);
110-
var hasCMove64 = labelledMachineInstructions.containsKey(MachineInstructionLabel.CMOVE_64);
111-
var conditionalMove = getConditionalMove(hasCMove32, hasCMove64, labelledMachineInstructions);
112103
var database = new Database(passResults, specification);
113104
var conditionalValueRange = getValueRangeCompareInstructions(database);
114105
var stackPointerType =
@@ -133,9 +124,6 @@ protected Map<String, Object> createVariables(final PassResults passResults,
133124
map.put("hasGlobalAddressLoad", abi.globalAddressLoad().isPresent());
134125
map.put("localAddressLoadInstruction",
135126
abi.localAddressLoad().map(x -> x.identifier().simpleName()).orElse(""));
136-
map.put("hasCMove32", hasCMove32);
137-
map.put("hasCMove64", hasCMove64);
138-
map.put("conditionalMove", conditionalMove);
139127
map.put("addImmediateInstruction", getAddImmediate(database));
140128
map.put("branchInstructions", getBranchInstructions(database));
141129
map.put("memoryInstructions", getMemoryInstructions(database));
@@ -177,6 +165,19 @@ protected Map<String, Object> createVariables(final PassResults passResults,
177165
new Query.Builder().machineInstructionLabel(
178166
MachineInstructionLabel.BSGEQ_BY_STATUS_REGISTER)
179167
.build())));
168+
map.put("CSEL_EQ",
169+
getFirstNameOrEmpty(database.run(
170+
new Query.Builder().machineInstructionLabel(
171+
stackPointerType == ValueType.I32
172+
? MachineInstructionLabel.CSEL_EQ_I32 :
173+
MachineInstructionLabel.CSEL_EQ_I64)
174+
.build())));
175+
map.put("CSEL_NEQ", getFirstNameOrEmpty(database.run(
176+
new Query.Builder().machineInstructionLabel(
177+
stackPointerType == ValueType.I32
178+
? MachineInstructionLabel.CSEL_NEQ_I32 :
179+
MachineInstructionLabel.CSEL_NEQ_I64)
180+
.build())));
180181
return map;
181182
}
182183

@@ -293,24 +294,6 @@ private List<BranchInstruction> getBranchInstructions(Database database) {
293294
}).toList();
294295
}
295296

296-
@Nullable
297-
private Instruction getConditionalMove(boolean hasCMove32,
298-
boolean hasCMove64,
299-
Map<MachineInstructionLabel,
300-
List<Instruction>> labelledMachineInstructions) {
301-
if (hasCMove64) {
302-
var cmove = labelledMachineInstructions.get(MachineInstructionLabel.CMOVE_32);
303-
ensureNonNull(cmove, "must not be null");
304-
return ensurePresent(cmove.stream().findFirst(), "At least one element should be present");
305-
} else if (hasCMove32) {
306-
var cmove = labelledMachineInstructions.get(MachineInstructionLabel.CMOVE_64);
307-
ensureNonNull(cmove, "must not be null");
308-
return ensurePresent(cmove.stream().findFirst(), "At least one element should be present");
309-
}
310-
311-
return null;
312-
}
313-
314297
private Map<String, Object> mapLlvmRegisterClass(LlvmRegisterFile registerFile) {
315298
return Map.of(
316299
"name", registerFile.simpleName(),

vadl/main/vadl/viam/graph/HasRegisterTensor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package vadl.viam.graph;
1818

19+
import vadl.types.DataType;
1920
import vadl.viam.RegisterTensor;
2021
import vadl.viam.graph.dependency.ExpressionNode;
2122

vadl/main/vadl/viam/graph/ReadsRegisterTensor.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,14 @@
1616

1717
package vadl.viam.graph;
1818

19+
import vadl.types.DataType;
20+
1921
/**
2022
* Interface to indicate that the implementing class reads from a register file.
2123
*/
2224
public interface ReadsRegisterTensor extends HasRegisterTensor {
23-
25+
/**
26+
* Get the type of the node.
27+
*/
28+
DataType type();
2429
}

vadl/test/vadl/lcb/aarch64/IsaMachineInstructionMatchingAarch64PassTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ private static Stream<Arguments> getExpectedMatchings() {
6565
Optional.empty()),
6666
Arguments.of(List.of("SUBWS", "SUBWSSXTX", "SUBWSUXTX"),
6767
MachineInstructionLabel.SUB_RR_WITH_STATUS_REGISTER_32,
68+
Optional.empty()),
69+
Arguments.of(List.of("CSELEQX"),
70+
MachineInstructionLabel.CSEL_EQ_I64,
71+
Optional.empty()),
72+
Arguments.of(List.of("CSELNEX"),
73+
MachineInstructionLabel.CSEL_NEQ_I64,
6874
Optional.empty())
6975
);
7076
}

0 commit comments

Comments
 (0)