Skip to content

Commit 0afeed8

Browse files
authored
Attention LSE migraphx integration (#1887)
* Attention LSE (log-sum-exp): migraphx integration This PR introduces migraphx integration for the optional LSE output (attention kernels).
1 parent 39423b6 commit 0afeed8

File tree

14 files changed

+867
-49
lines changed

14 files changed

+867
-49
lines changed

mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,19 @@ def MIGraphX_ReduceSumOp :
561561
}];
562562
}
563563

564+
def MIGraphX_ReduceMaxOp
565+
: MIGraphX_Op<"reduce_max">,
566+
Arguments<(ins AnyMIXRShaped:$input, I64ArrayAttr:$axes)>,
567+
Results<(outs AnyMIXRShaped:$output)> {
568+
let summary = "Get the max of the values in given axis";
569+
let description = [{
570+
The `migraphx.reduce_max` op.
571+
}];
572+
let assemblyFormat = [{
573+
$input attr-dict `:` type($input) `->` type($output)
574+
}];
575+
}
576+
564577
//--------- Execution layer Ops
565578
def MIGraphX_CodeObjOp :
566579
MIGraphX_Op<"code_object">,

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ def Rock_AttentionOp
209209
: Rock_Op<
210210
"attention", [DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>,
211211
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
212-
RockFusionRoot, AttrSizedOperandSegments]>,
212+
RockFusionRoot, AttrSizedOperandSegments,
213+
AttrSizedResultSegments]>,
213214
Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
214215
TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
215216
TensorOrMemRefOf<[F32, F16, BF16]>:$values,
@@ -223,7 +224,8 @@ def Rock_AttentionOp
223224
OptionalAttr<RockTuningParamAttrInterface>:$params0,
224225
OptionalAttr<RockTuningParamAttrInterface>:$params1,
225226
I32Attr:$firstGemmIdx)>,
226-
Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
227+
Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result,
228+
Optional<TensorOf<[F32, F16, BF16]>>:$lseOut)> {
227229
let summary = "Attention operation of transformer models";
228230
let description = [{
229231
Performs the operation out = SOFTMAX(preSoftmaxBody(queries * keys, preSoftmaxElemWiseInputs)) * values.
@@ -261,7 +263,12 @@ def Rock_AttentionOp
261263
(`lse` `=` $lse^ `:` type($lse) `\n`)?
262264
(`qk` `=` `elementwise` (`otherIns` `(` $preSoftmaxElemWiseInputs^ `:` type($preSoftmaxElemWiseInputs) `)`)? $preSoftmaxBody^ `\n`)?
263265
(`tr` $oTransposed^)? $out `=` `softmax` `(` `qk` `)` `*` (`tr` $vTransposed^)? $values `:` type($values) `->` type($out) `\n`
264-
`}` attr-dict (`->` type($result)^)?
266+
`}` attr-dict (`->` type($result)^)? (`,` type($lseOut)^)?
267+
}];
268+
269+
// Return operand of LSE if it exists.
270+
let extraClassDeclaration = [{
271+
::mlir::OpOperand* getOutLseArgument() { return getLse() ? &(*this)->getOpOperand(getNumOperands() - 1) : nullptr; }
265272
}];
266273
}
267274

mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -738,16 +738,20 @@ struct ReduceMeanConverter final
738738
ConversionPatternRewriter &rewriter) const final;
739739
};
740740

741-
struct ReduceSumConverter final
742-
: public OpConversionPattern<migraphx::ReduceSumOp> {
743-
using OpConversionPattern<migraphx::ReduceSumOp>::OpConversionPattern;
741+
namespace {
742+
template <typename MIGraphXOp, typename TosaOp>
743+
struct ReduceConverter final : public OpConversionPattern<MIGraphXOp> {
744+
using OpConversionPattern<MIGraphXOp>::OpConversionPattern;
745+
using OpAdaptor = typename OpConversionPattern<MIGraphXOp>::OpAdaptor;
744746

745747
LogicalResult
746-
matchAndRewrite(migraphx::ReduceSumOp op, OpAdaptor adaptor,
748+
matchAndRewrite(MIGraphXOp op, OpAdaptor adaptor,
747749
ConversionPatternRewriter &rewriter) const final;
748750
};
749751
} // namespace
750752

753+
} // namespace
754+
751755
tosa::ConstOp ReduceMeanConverter::createNumElementsTosaConst(
752756
Location loc, TypedValue<RankedTensorType> inputTensor,
753757
IntegerAttr axisAttr, ConversionPatternRewriter &rewriter) const {
@@ -805,9 +809,10 @@ LogicalResult ReduceMeanConverter::matchAndRewrite(
805809
return success();
806810
}
807811

808-
LogicalResult
809-
ReduceSumConverter::matchAndRewrite(migraphx::ReduceSumOp op, OpAdaptor adaptor,
810-
ConversionPatternRewriter &rewriter) const {
812+
template <typename MIGraphXOp, typename TosaOp>
813+
LogicalResult ReduceConverter<MIGraphXOp, TosaOp>::matchAndRewrite(
814+
MIGraphXOp op, OpAdaptor adaptor,
815+
ConversionPatternRewriter &rewriter) const {
811816
Location loc = op.getLoc();
812817
ArrayRef<Attribute> axes = op.getAxes().getValue();
813818
if (axes.size() != 1) {
@@ -817,9 +822,9 @@ ReduceSumConverter::matchAndRewrite(migraphx::ReduceSumOp op, OpAdaptor adaptor,
817822
rewriter.getI32IntegerAttr(cast<IntegerAttr>(axes[0]).getInt());
818823
auto input = cast<TypedValue<RankedTensorType>>(adaptor.getInput());
819824
Type elementType = input.getType().getElementType();
820-
auto tosaReduceSum = createOpAndInfer<tosa::ReduceSumOp>(
821-
rewriter, loc, elementType, input, axis);
822-
rewriter.replaceOp(op, tosaReduceSum);
825+
auto tosaReduce =
826+
createOpAndInfer<TosaOp>(rewriter, loc, elementType, input, axis);
827+
rewriter.replaceOp(op, tosaReduce);
823828
return success();
824829
}
825830

@@ -1493,7 +1498,9 @@ void migraphx::populateMIGraphXToTosaConversionPatterns(
14931498
DotConverter<DotOp>, DotConverter<QuantDotOp>,
14941499
BroadcastConverter, MultiBroadcastConverter, TransposeConverter,
14951500
ReshapeConverter, SliceConverter, ReduceMeanConverter,
1496-
ReduceSumConverter, TrivialConverter<AddOp, tosa::AddOp>,
1501+
ReduceConverter<ReduceSumOp, tosa::ReduceSumOp>,
1502+
ReduceConverter<ReduceMaxOp, tosa::ReduceMaxOp>,
1503+
TrivialConverter<AddOp, tosa::AddOp>,
14971504
TrivialConverter<SubOp, tosa::SubOp>,
14981505
TrivialConverter<PowOp, tosa::PowOp>, DivConverter, MulConverter,
14991506
TrivialConverter<AbsOp, tosa::AbsOp>,

0 commit comments

Comments
 (0)