Skip to content

Commit 8971414

Browse files
authored
[mlir][spirv] Add floating point dot product (#73466)
Because `OpDot` does not require any extra capabilities or extensions, enable it by default in the vector to spirv conversion.
1 parent a369a59 commit 8971414

File tree

5 files changed

+141
-6
lines changed

5 files changed

+141
-6
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,40 @@ def SPIRV_ISubBorrowOp : SPIRV_ArithmeticExtendedBinaryOp<"ISubBorrow",
503503

504504
// -----
505505

506+
def SPIRV_DotOp : SPIRV_Op<"Dot",
507+
[Pure, AllTypesMatch<["vector1", "vector2"]>,
508+
AllElementTypesMatch<["vector1", "result"]>]> {
509+
let summary = "Dot product of Vector 1 and Vector 2";
510+
511+
let description = [{
512+
Result Type must be a floating point scalar.
513+
514+
Vector 1 and Vector 2 must be vectors of the same type, and their component
515+
type must be Result Type.
516+
517+
#### Example:
518+
519+
```mlir
520+
%0 = spirv.Dot %v1, %v2 : vector<4xf32> -> f32
521+
```
522+
}];
523+
524+
let arguments = (ins
525+
SPIRV_VectorOf<SPIRV_Float>:$vector1,
526+
SPIRV_VectorOf<SPIRV_Float>:$vector2
527+
);
528+
529+
let results = (outs
530+
SPIRV_Float:$result
531+
);
532+
533+
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
534+
535+
let hasVerifier = 0;
536+
}
537+
538+
// -----
539+
506540
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
507541
SPIRV_Integer,
508542
[UsableInSpecConstantOp]> {

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4205,11 +4205,14 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
42054205
"::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
42064206
"Joint Matrix">;
42074207

4208+
class SPIRV_VectorOf<Type type> :
4209+
VectorOfLengthAndType<[2, 3, 4, 8,16], [type]>;
4210+
42084211
class SPIRV_ScalarOrVectorOf<Type type> :
4209-
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
4212+
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
42104213

42114214
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
4212-
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
4215+
AnyTypeOf<[type, SPIRV_VectorOf<type>,
42134216
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
42144217

42154218
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
@@ -4357,6 +4360,7 @@ def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
43574360
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
43584361
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
43594362
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
4363+
def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
43604364
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
43614365
def SPIRV_OC_OpISubBorrow : I32EnumAttrCase<"OpISubBorrow", 150>;
43624366
def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended", 151>;
@@ -4526,7 +4530,7 @@ def SPIRV_OpcodeAttr :
45264530
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
45274531
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
45284532
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
4529-
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
4533+
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpDot,
45304534
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow,
45314535
SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, SPIRV_OC_OpIsNan,
45324536
SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,8 @@ struct VectorStoreOpConverter final
646646
}
647647
};
648648

649-
struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
649+
struct VectorReductionToIntDotProd final
650+
: OpRewritePattern<vector::ReductionOp> {
650651
using OpRewritePattern::OpRewritePattern;
651652

652653
LogicalResult matchAndRewrite(vector::ReductionOp op,
@@ -740,6 +741,36 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
740741
}
741742
};
742743

744+
struct VectorReductionToFPDotProd final
745+
: OpConversionPattern<vector::ReductionOp> {
746+
using OpConversionPattern::OpConversionPattern;
747+
748+
LogicalResult
749+
matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
750+
ConversionPatternRewriter &rewriter) const override {
751+
if (op.getKind() != vector::CombiningKind::ADD)
752+
return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");
753+
754+
auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
755+
if (!resultType)
756+
return rewriter.notifyMatchFailure(op, "result is not a float");
757+
758+
auto mul = adaptor.getVector().getDefiningOp<arith::MulFOp>();
759+
if (!mul)
760+
return rewriter.notifyMatchFailure(
761+
op, "reduction operand is not 'arith.mulf'");
762+
763+
Location loc = op.getLoc();
764+
Value res = rewriter.create<spirv::DotOp>(loc, resultType, mul.getLhs(),
765+
mul.getRhs());
766+
if (op.getAcc())
767+
res = rewriter.create<spirv::FAddOp>(loc, adaptor.getAcc(), res);
768+
769+
rewriter.replaceOp(op, res);
770+
return success();
771+
}
772+
};
773+
743774
} // namespace
744775
#define CL_INT_MAX_MIN_OPS \
745776
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -763,10 +794,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
763794
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
764795
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
765796
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
766-
typeConverter, patterns.getContext());
797+
typeConverter, patterns.getContext(), PatternBenefit(1));
798+
799+
// Make sure that the more specialized dot product pattern has higher benefit
800+
// than the generic one that extracts all elements.
801+
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
802+
PatternBenefit(2));
767803
}
768804

769805
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
770806
RewritePatternSet &patterns) {
771-
patterns.add<VectorReductionToDotProd>(patterns.getContext());
807+
patterns.add<VectorReductionToIntDotProd>(patterns.getContext());
772808
}

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,31 @@ func.func @reduction_add(%v : vector<4xi32>) -> i32 {
500500

501501
// -----
502502

503+
// CHECK-LABEL: func @reduction_addf
504+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
505+
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
506+
// CHECK: return %[[DOT]] : f32
507+
func.func @reduction_addf(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
508+
%mul = arith.mulf %arg0, %arg1 : vector<4xf32>
509+
%red = vector.reduction <add>, %mul : vector<4xf32> into f32
510+
return %red : f32
511+
}
512+
513+
// -----
514+
515+
// CHECK-LABEL: func @reduction_addf_acc
516+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ACC:.+]]: f32)
517+
// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xf32> -> f32
518+
// CHECK: %[[RES:.+]] = spirv.FAdd %[[ACC]], %[[DOT]] : f32
519+
// CHECK: return %[[RES]] : f32
520+
func.func @reduction_addf_acc(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %acc: f32) -> f32 {
521+
%mul = arith.mulf %arg0, %arg1 : vector<4xf32>
522+
%red = vector.reduction <add>, %mul, %acc : vector<4xf32> into f32
523+
return %red : f32
524+
}
525+
526+
// -----
527+
503528
// CHECK-LABEL: func @reduction_mul
504529
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
505530
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>

mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,42 @@ func.func @isub_borrow(%arg: i64) -> !spirv.struct<(i32, i32)> {
254254

255255
// -----
256256

257+
//===----------------------------------------------------------------------===//
258+
// spirv.Dot
259+
//===----------------------------------------------------------------------===//
260+
261+
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
262+
%0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32
263+
return %0 : f32
264+
}
265+
266+
// -----
267+
268+
// expected-note @+1 {{prior use here}}
269+
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
270+
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
271+
%0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f32
272+
return %0 : f32
273+
}
274+
275+
// -----
276+
277+
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
278+
// expected-error @+1 {{'spirv.Dot' op failed to verify that all of {vector1, result} have same element type}}
279+
%0 = spirv.Dot %arg0, %arg1 : vector<4xf32> -> f16
280+
return %0 : f16
281+
}
282+
283+
// -----
284+
285+
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
286+
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
287+
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
288+
return %0 : i32
289+
}
290+
291+
// -----
292+
257293
//===----------------------------------------------------------------------===//
258294
// spirv.SMulExtended
259295
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)