Skip to content

Commit 230b189

Browse files
authored
[mlir][spirv] Add folding for [S|U|LessThan[Equal] (#85435)
Add missing constant propogation folder for [S|U]LessThan[Equal]. Implement additional folding when the operands are equal for all ops. Allows for constant folding in the IndexToSPIRV pass. Part of work #70704
1 parent cceedc9 commit 230b189

File tree

3 files changed

+266
-0
lines changed

3 files changed

+266
-0
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,8 @@ def SPIRV_SLessThanOp : SPIRV_LogicalBinaryOp<"SLessThan",
716716

717717
```
718718
}];
719+
720+
let hasFolder = 1;
719721
}
720722

721723
// -----
@@ -745,6 +747,8 @@ def SPIRV_SLessThanEqualOp : SPIRV_LogicalBinaryOp<"SLessThanEqual",
745747
%5 = spirv.SLessThanEqual %2, %3 : vector<4xi32>
746748
```
747749
}];
750+
751+
let hasFolder = 1;
748752
}
749753

750754
// -----
@@ -886,6 +890,8 @@ def SPIRV_ULessThanOp : SPIRV_LogicalBinaryOp<"ULessThan",
886890
%5 = spirv.ULessThan %2, %3 : vector<4xi32>
887891
```
888892
}];
893+
894+
let hasFolder = 1;
889895
}
890896

891897
// -----
@@ -949,6 +955,8 @@ def SPIRV_ULessThanEqualOp : SPIRV_LogicalBinaryOp<"ULessThanEqual",
949955
%5 = spirv.ULessThanEqual %2, %3 : vector<4xi32>
950956
```
951957
}];
958+
959+
let hasFolder = 1;
952960
}
953961

954962
#endif // MLIR_DIALECT_SPIRV_IR_LOGICAL_OPS

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,88 @@ OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
880880
});
881881
}
882882

883+
//===----------------------------------------------------------------------===//
884+
// spirv.SLessThan
885+
//===----------------------------------------------------------------------===//
886+
887+
OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
888+
// x == x -> false
889+
if (getOperand1() == getOperand2()) {
890+
auto falseAttr = BoolAttr::get(getContext(), false);
891+
if (isa<IntegerType>(getType()))
892+
return falseAttr;
893+
if (auto vecTy = dyn_cast<VectorType>(getType()))
894+
return SplatElementsAttr::get(vecTy, falseAttr);
895+
}
896+
897+
return constFoldBinaryOp<IntegerAttr>(
898+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
899+
return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
900+
});
901+
}
902+
903+
//===----------------------------------------------------------------------===//
904+
// spirv.SLessThanEqual
905+
//===----------------------------------------------------------------------===//
906+
907+
OpFoldResult
908+
spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
909+
// x == x -> true
910+
if (getOperand1() == getOperand2()) {
911+
auto trueAttr = BoolAttr::get(getContext(), true);
912+
if (isa<IntegerType>(getType()))
913+
return trueAttr;
914+
if (auto vecTy = dyn_cast<VectorType>(getType()))
915+
return SplatElementsAttr::get(vecTy, trueAttr);
916+
}
917+
918+
return constFoldBinaryOp<IntegerAttr>(
919+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
920+
return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
921+
});
922+
}
923+
924+
//===----------------------------------------------------------------------===//
925+
// spirv.ULessThan
926+
//===----------------------------------------------------------------------===//
927+
928+
OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
929+
// x == x -> false
930+
if (getOperand1() == getOperand2()) {
931+
auto falseAttr = BoolAttr::get(getContext(), false);
932+
if (isa<IntegerType>(getType()))
933+
return falseAttr;
934+
if (auto vecTy = dyn_cast<VectorType>(getType()))
935+
return SplatElementsAttr::get(vecTy, falseAttr);
936+
}
937+
938+
return constFoldBinaryOp<IntegerAttr>(
939+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
940+
return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
941+
});
942+
}
943+
944+
//===----------------------------------------------------------------------===//
945+
// spirv.ULessThanEqual
946+
//===----------------------------------------------------------------------===//
947+
948+
OpFoldResult
949+
spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
950+
// x == x -> true
951+
if (getOperand1() == getOperand2()) {
952+
auto trueAttr = BoolAttr::get(getContext(), true);
953+
if (isa<IntegerType>(getType()))
954+
return trueAttr;
955+
if (auto vecTy = dyn_cast<VectorType>(getType()))
956+
return SplatElementsAttr::get(vecTy, trueAttr);
957+
}
958+
959+
return constFoldBinaryOp<IntegerAttr>(
960+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
961+
return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
962+
});
963+
}
964+
883965
//===----------------------------------------------------------------------===//
884966
// spirv.ShiftLeftLogical
885967
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,182 @@ func.func @const_fold_vector_inotequal() -> vector<3xi1> {
14781478

14791479
// -----
14801480

1481+
//===----------------------------------------------------------------------===//
1482+
// spirv.SLessThan
1483+
//===----------------------------------------------------------------------===//
1484+
1485+
// CHECK-LABEL: @slt_same
1486+
func.func @slt_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1487+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1488+
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
1489+
%0 = spirv.SLessThan %arg0, %arg0 : i32
1490+
%1 = spirv.SLessThan %arg1, %arg1 : vector<3xi32>
1491+
1492+
// CHECK: return %[[CFALSE]], %[[CVFALSE]]
1493+
return %0, %1 : i1, vector<3xi1>
1494+
}
1495+
1496+
// CHECK-LABEL: @const_fold_scalar_slt
1497+
func.func @const_fold_scalar_slt() -> (i1, i1) {
1498+
%c4 = spirv.Constant 4 : i32
1499+
%c5 = spirv.Constant 5 : i32
1500+
%c6 = spirv.Constant 6 : i32
1501+
1502+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1503+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1504+
%0 = spirv.SLessThan %c5, %c6 : i32
1505+
%1 = spirv.SLessThan %c5, %c4 : i32
1506+
1507+
// CHECK: return %[[CTRUE]], %[[CFALSE]]
1508+
return %0, %1 : i1, i1
1509+
}
1510+
1511+
// CHECK-LABEL: @const_fold_vector_slt
1512+
func.func @const_fold_vector_slt() -> vector<3xi1> {
1513+
%cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
1514+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1515+
1516+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
1517+
%0 = spirv.SLessThan %cv0, %cv1 : vector<3xi32>
1518+
1519+
// CHECK: return %[[RET]]
1520+
return %0 : vector<3xi1>
1521+
}
1522+
1523+
// -----
1524+
1525+
//===----------------------------------------------------------------------===//
1526+
// spirv.SLessThanEqual
1527+
//===----------------------------------------------------------------------===//
1528+
1529+
// CHECK-LABEL: @sle_same
1530+
func.func @sle_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1531+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1532+
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
1533+
%0 = spirv.SLessThanEqual %arg0, %arg0 : i32
1534+
%1 = spirv.SLessThanEqual %arg1, %arg1 : vector<3xi32>
1535+
1536+
// CHECK: return %[[CTRUE]], %[[CVTRUE]]
1537+
return %0, %1 : i1, vector<3xi1>
1538+
}
1539+
1540+
// CHECK-LABEL: @const_fold_scalar_sle
1541+
func.func @const_fold_scalar_sle() -> (i1, i1) {
1542+
%c4 = spirv.Constant 4 : i32
1543+
%c5 = spirv.Constant 5 : i32
1544+
%c6 = spirv.Constant 6 : i32
1545+
1546+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1547+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1548+
%0 = spirv.SLessThanEqual %c5, %c6 : i32
1549+
%1 = spirv.SLessThanEqual %c5, %c4 : i32
1550+
1551+
// CHECK: return %[[CTRUE]], %[[CFALSE]]
1552+
return %0, %1 : i1, i1
1553+
}
1554+
1555+
// CHECK-LABEL: @const_fold_vector_sle
1556+
func.func @const_fold_vector_sle() -> vector<3xi1> {
1557+
%cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
1558+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1559+
1560+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
1561+
%0 = spirv.SLessThanEqual %cv0, %cv1 : vector<3xi32>
1562+
1563+
// CHECK: return %[[RET]]
1564+
return %0 : vector<3xi1>
1565+
}
1566+
1567+
// -----
1568+
1569+
//===----------------------------------------------------------------------===//
1570+
// spirv.ULessThan
1571+
//===----------------------------------------------------------------------===//
1572+
1573+
// CHECK-LABEL: @ult_same
1574+
func.func @ult_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1575+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1576+
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
1577+
%0 = spirv.ULessThan %arg0, %arg0 : i32
1578+
%1 = spirv.ULessThan %arg1, %arg1 : vector<3xi32>
1579+
1580+
// CHECK: return %[[CFALSE]], %[[CVFALSE]]
1581+
return %0, %1 : i1, vector<3xi1>
1582+
}
1583+
1584+
// CHECK-LABEL: @const_fold_scalar_ult
1585+
func.func @const_fold_scalar_ult() -> (i1, i1) {
1586+
%c4 = spirv.Constant 4 : i32
1587+
%c5 = spirv.Constant 5 : i32
1588+
%cn6 = spirv.Constant -6 : i32
1589+
1590+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1591+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1592+
%0 = spirv.ULessThan %c5, %cn6 : i32
1593+
%1 = spirv.ULessThan %c5, %c4 : i32
1594+
1595+
// CHECK: return %[[CTRUE]], %[[CFALSE]]
1596+
return %0, %1 : i1, i1
1597+
}
1598+
1599+
// CHECK-LABEL: @const_fold_vector_ult
1600+
func.func @const_fold_vector_ult() -> vector<3xi1> {
1601+
%cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
1602+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1603+
1604+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
1605+
%0 = spirv.ULessThan %cv0, %cv1 : vector<3xi32>
1606+
1607+
// CHECK: return %[[RET]]
1608+
return %0 : vector<3xi1>
1609+
}
1610+
1611+
// -----
1612+
1613+
//===----------------------------------------------------------------------===//
1614+
// spirv.ULessThanEqual
1615+
//===----------------------------------------------------------------------===//
1616+
1617+
// CHECK-LABEL: @ule_same
1618+
func.func @ule_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1619+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1620+
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
1621+
%0 = spirv.ULessThanEqual %arg0, %arg0 : i32
1622+
%1 = spirv.ULessThanEqual %arg1, %arg1 : vector<3xi32>
1623+
1624+
// CHECK: return %[[CTRUE]], %[[CVTRUE]]
1625+
return %0, %1 : i1, vector<3xi1>
1626+
}
1627+
1628+
// CHECK-LABEL: @const_fold_scalar_ule
1629+
func.func @const_fold_scalar_ule() -> (i1, i1) {
1630+
%c4 = spirv.Constant 4 : i32
1631+
%c5 = spirv.Constant 5 : i32
1632+
%cn6 = spirv.Constant -6 : i32
1633+
1634+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1635+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1636+
%0 = spirv.ULessThanEqual %c5, %cn6 : i32
1637+
%1 = spirv.ULessThanEqual %c5, %c4 : i32
1638+
1639+
// CHECK: return %[[CTRUE]], %[[CFALSE]]
1640+
return %0, %1 : i1, i1
1641+
}
1642+
1643+
// CHECK-LABEL: @const_fold_vector_ule
1644+
func.func @const_fold_vector_ule() -> vector<3xi1> {
1645+
%cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
1646+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1647+
1648+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
1649+
%0 = spirv.ULessThanEqual %cv0, %cv1 : vector<3xi32>
1650+
1651+
// CHECK: return %[[RET]]
1652+
return %0 : vector<3xi1>
1653+
}
1654+
1655+
// -----
1656+
14811657
//===----------------------------------------------------------------------===//
14821658
// spirv.LeftShiftLogical
14831659
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)