66//
77// ===----------------------------------------------------------------------===//
88#include " gc/Analysis/VectorBasedFusionAnalysis.h"
9- #include " mlir /Dialect/Linalg/IR/Linalg .h"
9+ #include " gc /Dialect/Linalgx/Utils .h"
1010
1111namespace mlir {
1212namespace gc {
@@ -22,16 +22,16 @@ namespace gc {
2222 arith::TruncFOp, arith::TruncIOp
2323
2424#define NOT_NEED_TO_PROCESS_OP \
25- linalg::GenericOp, linalg::BatchReduceMatmulOp, linalg::MatmulOp, \
26- linalg::BatchMatmulOp, linalg::BatchMatmulTransposeAOp, \
27- linalg::BatchMatmulTransposeBOp, linalg::MatmulTransposeAOp, \
28- linalg::MatmulTransposeBOp, linalg::QuantizedBatchMatmulOp, \
29- linalg::QuantizedMatmulOp, tensor::CollapseShapeOp, \
30- tensor::ExpandShapeOp, tensor::ExtractSliceOp, tensor::InsertSliceOp, \
31- microkernel::BrgemmOp
25+ linalg::BatchReduceMatmulOp, linalg::MatmulOp, linalg::BatchMatmulOp, \
26+ linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \
27+ linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \
28+ linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \
29+ tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::ExtractSliceOp, \
30+ tensor::InsertSliceOp, microkernel::BrgemmOp
3231
3332static inline bool isNotNeedToProcessOp (Operation *op) {
34- return isa<NOT_NEED_TO_PROCESS_OP>(op);
33+ return isa<NOT_NEED_TO_PROCESS_OP>(op) or
34+ linalgx::isAnyGenericPackedMatmulOp (op);
3535}
3636
3737static inline bool isSpecialOp (Operation *op) {
@@ -72,7 +72,7 @@ void shapeCastSourceAxis(const ArrayRef<int64_t> &a, const ArrayRef<int64_t> &b,
7272 while (dimB < dimA && j < rankB)
7373 dimB *= b[j++];
7474 if (dimA != dimB) {
75- assert ( false && " Invalid shape cast operation." );
75+ llvm::llvm_unreachable_internal ( " Invalid shape cast operation." );
7676 break ;
7777 }
7878 if (bAxisBegin != j) {
@@ -87,12 +87,13 @@ void shapeCastSourceAxis(const ArrayRef<int64_t> &a, const ArrayRef<int64_t> &b,
8787 if (j < rankB && all_of (b.slice (j), isOne))
8888 j = rankB;
8989 }
90-
91- assert (i == rankA && j == rankB && " Invalid shapecast operation." );
90+ if (i != rankA or j != rankB)
91+ llvm_unreachable ( " Invalid shapecast operation." );
9292}
9393
9494bool isScalar (Type type) {
95- assert (type && " Not a valid type" );
95+ if (not type)
96+ llvm_unreachable (" Not a valid type" );
9697 if (auto vecType = dyn_cast<VectorType>(type))
9798 return false ;
9899 if (auto tensorType = dyn_cast<TensorType>(type))
@@ -107,8 +108,8 @@ void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output,
107108 // following auto_broadcast semantics
108109 const size_t input_rank = inputShape.size ();
109110 const size_t output_rank = outputShape.size ();
110- assert (output_rank >= input_rank &&
111- " Incorrect input or output shape for broadcast op." );
111+ if (output_rank < input_rank)
112+ llvm_unreachable ( " Incorrect input or output shape for broadcast op." );
112113 const size_t offset = output_rank - input_rank;
113114 for (size_t i = 0 ; i < input_rank; ++i) {
114115 if (inputShape[i] == outputShape[i + offset] ||
@@ -390,13 +391,16 @@ mlir::FailureOr<VectorType> getOperationMaxVectorType(Operation *op) {
390391
391392// / select nearest even step
392393int getNearestVectorStep (const int step) {
393- assert (step > 0 );
394+ if (step <= 0 )
395+ llvm_unreachable (" Wrong step." );
396+
394397 int nbits = 0 , n = step;
395398 while (n) {
396399 n = n >> 1 ;
397400 nbits++;
398401 }
399- assert (nbits <= 6 || (nbits == 7 && step == 64 ));
402+ if (nbits > 6 and !(nbits == 7 && step == 64 ))
403+ llvm_unreachable (" wrong nbits appear" );
400404 return (1 << (nbits - 1 )) == step ? step : (1 << nbits);
401405}
402406
@@ -488,7 +492,7 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) {
488492 // down into a loop.
489493 mlir::FailureOr<VectorType> baseType = getOperationVectorType (op);
490494 if (failed (baseType)) {
491- assert ( 0 && " Failed to get vector type for operation" );
495+ llvm_unreachable ( " Failed to get vector type for operation" );
492496 return VectorType ();
493497 }
494498 auto vectorizedType = baseType.value ();
@@ -518,7 +522,7 @@ int TypeHelper::generateValidSteps(int steps, VectorType type) {
518522 return favx2bits / typebits;
519523
520524 // invalid hardware
521- assert ( false && " Invalid hardware." );
525+ llvm_unreachable ( " Invalid hardware." );
522526 return -1 ;
523527}
524528
@@ -590,7 +594,8 @@ void GroupOperationFusion::updateGroupBigestVectorType(VectorType vectorType) {
590594}
591595
592596void GroupOperationFusion::addOperationToGroup (Operation *op) {
593- assert (op);
597+ if (not op)
598+ llvm_unreachable (" Op can't be NULL." );
594599 VectorType vectorType = getOperationMaxVectorType (op).value ();
595600 if (isNeedNewGroup (op))
596601 opGroups.emplace_back (std::queue<Operation *>());
0 commit comments