Skip to content

Commit eababdc

Browse files
committed
add support for single element vectors
1 parent 68e6fb4 commit eababdc

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,14 @@ struct IndexCastIndexI1Pattern final
622622
if (!isBoolScalarOrVector(op.getType()))
623623
return failure();
624624

625+
Type dstType = getTypeConverter()->convertType(op.getType());
626+
if (!dstType)
627+
return getTypeConversionFailure(rewriter, op);
628+
625629
Location loc = op.getLoc();
626630
Value zeroIdx =
627-
spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
628-
rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, op.getType(), zeroIdx,
631+
spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
632+
rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
629633
adaptor.getIn());
630634
return success();
631635
}

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,16 @@ func.func @index_castindexi1_1(%arg0 : index) {
743743
}
744744

745745
// CHECK-LABEL: index_castindexi1_2
746-
func.func @index_castindexi1_2(%arg0 : vector<3xindex>) {
746+
func.func @index_castindexi1_2(%arg0 : vector<1xindex>) -> vector<1xi1> {
747+
// Single-element vectors do not exist in SPIRV.
748+
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
749+
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
750+
%0 = arith.index_cast %arg0 : vector<1xindex> to vector<1xi1>
751+
return %0 : vector<1xi1>
752+
}
753+
754+
// CHECK-LABEL: index_castindexi1_3
755+
func.func @index_castindexi1_3(%arg0 : vector<3xindex>) {
747756
// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
748757
// CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : vector<3xi32>
749758
%0 = arith.index_cast %arg0 : vector<3xindex> to vector<3xi1>

0 commit comments

Comments
 (0)