5151#include " absl/log/log.h"
5252#include " absl/status/status.h"
5353#include " absl/types/span.h"
54+ #include " llvm/include/llvm/ADT/APInt.h"
5455#include " mlir/include/mlir/Dialect/Arith/IR/Arith.h"
5556#include " mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
5657#include " mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
58+ #include " mlir/include/mlir/IR/Attributes.h"
5759#include " mlir/include/mlir/IR/Builders.h"
5860#include " mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
5961#include " mlir/include/mlir/IR/OperationSupport.h"
@@ -554,7 +556,7 @@ VectorType getNativeVregType(Type elem_ty,
554556FailureOr<Value> maskOOB (RewriteContext &ctx, OpBuilder &builder,
555557 TypedValue<VectorType> value,
556558 const VRegDataBounds &bounds,
557- const TypedAttr neutral) {
559+ const Attribute neutral) {
558560 auto native_vreg_ty =
559561 getNativeVregType (value.getType ().getElementType (), ctx.target_shape );
560562 TPU_ASSERT_LOC (value.getLoc (), llvm::equal (value.getType ().getShape (),
@@ -3926,6 +3928,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
39263928 ImplicitLocOpBuilder builder (op.getLoc (), &op);
39273929 auto multi_reduction_op = cast<vector::MultiDimReductionOp>(op);
39283930 const VectorType src_ty = multi_reduction_op.getSourceVectorType ();
3931+ auto element_type = src_ty.getElementType ();
39293932 int64_t src_rank = src_ty.getRank ();
39303933 const auto res_ty = dyn_cast<VectorType>(multi_reduction_op.getDestType ());
39313934 if (res_ty == nullptr ) {
@@ -3953,44 +3956,56 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
39533956 return multi_reduction_op.emitOpError (
39543957 " Not implemented: Only constant accumulator supported" );
39553958 }
3956- if (!src_ty.getElementType ().isF32 () && !src_ty.getElementType ().isBF16 ()) {
3959+ if (!element_type.isF32 () && !element_type.isBF16 () &&
3960+ !element_type.isSignlessInteger ((32 ))) {
39573961 return multi_reduction_op.emitOpError (
3958- " Not implemented: Only FP32 and BF16 reductions supported, but "
3959- " got " )
3960- << src_ty;
3962+ " Not implemented: unsupported element type" );
39613963 }
3962- auto element_type = cast<FloatType>(src_ty. getElementType () );
3963- const auto acc_def_value = dyn_cast<DenseFPElementsAttr >(acc_def.getValue ());
3964+ bool is_int = element_type. isSignlessInteger ( 32 );
3965+ const auto acc_def_value = dyn_cast<DenseElementsAttr >(acc_def.getValue ());
39643966 if (acc_def_value == nullptr || !acc_def_value.isSplat ()) {
39653967 return multi_reduction_op.emitOpError (" Expected a splat constant" );
39663968 }
39673969 TPU_ASSERT_OP (acc_def_value.getElementType () == element_type);
3968- const auto val = acc_def_value.getSplatValue <FloatAttr>();
3969- FloatAttr neutral;
3970+ Attribute neutral;
39703971 switch (multi_reduction_op.getKind ()) {
39713972 case vector::CombiningKind::ADD:
3972- neutral = builder.getFloatAttr (element_type, 0 );
3973+ neutral = builder.getZeroAttr (element_type);
39733974 break ;
39743975 case vector::CombiningKind::MAXIMUMF: {
39753976 // TODO(b/322836633): The semantics of maximumf don't match the lowering
39763977 // for older TPU versions because older TPU versions don't respect the
39773978 // -0.0 vs +0.0 ordering.
39783979 neutral = builder.getFloatAttr (
3979- element_type, APFloat::getInf (element_type.getFloatSemantics (),
3980- /* Negative=*/ true ));
3980+ element_type,
3981+ APFloat::getInf (cast<FloatType>(element_type).getFloatSemantics (),
3982+ /* Negative=*/ true ));
39813983 } break ;
39823984 case vector::CombiningKind::MINIMUMF: {
39833985 neutral = builder.getFloatAttr (
3984- element_type, APFloat::getInf (element_type.getFloatSemantics (),
3985- /* Negative=*/ false ));
3986+ element_type,
3987+ APFloat::getInf (cast<FloatType>(element_type).getFloatSemantics (),
3988+ /* Negative=*/ false ));
3989+ } break ;
3990+ case vector::CombiningKind::MAXSI: {
3991+ neutral = builder.getIntegerAttr (
3992+ element_type,
3993+ APInt::getSignedMinValue (element_type.getIntOrFloatBitWidth ()));
3994+ } break ;
3995+ case vector::CombiningKind::MINSI: {
3996+ neutral = builder.getIntegerAttr (
3997+ element_type,
3998+ APInt::getSignedMaxValue (element_type.getIntOrFloatBitWidth ()));
39863999 } break ;
39874000 default :
39884001 return multi_reduction_op.emitOpError (
39894002 " Not implemented: unsupported kind" );
39904003 }
3991- if (val != neutral) {
4004+ if (auto val = acc_def_value. getSplatValue <Attribute>(); val != neutral) {
39924005 return multi_reduction_op.emitOpError (
3993- " Not implemented: Only neutral accumulator supported" );
4006+ " Not implemented: Only neutral accumulator supported for "
4007+ " float reduction. Expected " )
4008+ << neutral << " , but got " << val;
39944009 }
39954010
39964011 std::array<bool , 2 > reduces;
@@ -4074,9 +4089,11 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
40744089 tpu_kind = tpu::ReductionKind::SUM;
40754090 break ;
40764091 case vector::CombiningKind::MAXIMUMF:
4092+ case vector::CombiningKind::MAXSI:
40774093 tpu_kind = tpu::ReductionKind::MAX;
40784094 break ;
40794095 case vector::CombiningKind::MINIMUMF:
4096+ case vector::CombiningKind::MINSI:
40804097 tpu_kind = tpu::ReductionKind::MIN;
40814098 break ;
40824099 default :
@@ -4103,14 +4120,29 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
41034120 src_vregs.Slice (src_slice_start, src_slice_end);
41044121 std::optional<Value> acc_vreg;
41054122 auto reduce_elementwise = [&](Value lhs, Value rhs) -> Value {
4123+ Value result;
41064124 switch (tpu_kind) {
41074125 case tpu::ReductionKind::SUM:
4108- return builder.create <arith::AddFOp>(loc, lhs, rhs);
4126+ result =
4127+ is_int
4128+ ? builder.create <arith::AddIOp>(loc, lhs, rhs).getResult ()
4129+ : builder.create <arith::AddFOp>(loc, lhs, rhs)
4130+ .getResult ();
4131+ break ;
41094132 case tpu::ReductionKind::MAX:
4110- return builder.create <arith::MaximumFOp>(loc, lhs, rhs);
4133+ result = is_int ? builder.create <arith::MaxSIOp>(loc, lhs, rhs)
4134+ .getResult ()
4135+ : builder.create <arith::MaximumFOp>(loc, lhs, rhs)
4136+ .getResult ();
4137+ break ;
41114138 case tpu::ReductionKind::MIN:
4112- return builder.create <arith::MinimumFOp>(loc, lhs, rhs);
4139+ result = is_int ? builder.create <arith::MinSIOp>(loc, lhs, rhs)
4140+ .getResult ()
4141+ : builder.create <arith::MinimumFOp>(loc, lhs, rhs)
4142+ .getResult ();
4143+ break ;
41134144 }
4145+ return result;
41144146 };
41154147 auto reduction_status = reduced_vregs.EachStatus (
41164148 [&](const absl::Span<const int64_t > red_idx,
0 commit comments