Skip to content

Commit a79147c

Browse files
author
Xu, Xiaohui1
committed
add utils.cpp
1 parent 1f0e3ce commit a79147c

File tree

6 files changed

+309
-245
lines changed

6 files changed

+309
-245
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//===-- VectorUtils.h ----- vector fusion analysis --------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_TRANSFORMS_UTILS_VECTORUTILS_H
10+
#define GC_TRANSFORMS_UTILS_VECTORUTILS_H
11+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
12+
#include "mlir/IR/BuiltinTypes.h"
13+
#include "mlir/IR/TypeUtilities.h"
14+
#include <limits>
15+
#include <stdint.h>
16+
#include <variant>
17+
18+
namespace mlir {
19+
namespace gc {
20+
/// Bit-level view of an IEEE-754 binary32 value: write one member and read
/// the other to reinterpret a float's raw bit pattern. Used by the
/// float <-> half/bfloat16 conversion helpers declared in this header.
union Float32Bits {
  uint32_t u; // raw bit pattern
  float f;    // floating-point value
};
24+
uint16_t float2half(float floatValue);
25+
float half2float(uint16_t halfValue);
26+
uint16_t float2bfloat(float floatValue);
27+
float bfloat2float(uint16_t bfloatBits);
28+
std::variant<float, int64_t> numeric_limits_minimum(Type type);
29+
std::variant<float, int64_t> numericLimitsMaximum(Type type);
30+
31+
template <typename T = float>
32+
T getInitValForReduce(vector::CombiningKind kind, Type t) {
33+
T result;
34+
Type t1 = getElementTypeOrSelf(t);
35+
36+
switch (kind) {
37+
case vector::CombiningKind::ADD:
38+
if (t1.isIntOrIndex())
39+
result = 0;
40+
else if (isa<FloatType>(t1))
41+
result = 0.0f;
42+
else
43+
llvm_unreachable("invalid value types for ADD reduction");
44+
break;
45+
case vector::CombiningKind::MAXNUMF:
46+
case vector::CombiningKind::MAXIMUMF:
47+
if (not isa<FloatType>(t1))
48+
llvm_unreachable("Expected float values.");
49+
result = std::get<T>(numeric_limits_minimum(t));
50+
break;
51+
case vector::CombiningKind::MINNUMF:
52+
case vector::CombiningKind::MINIMUMF:
53+
if (not isa<FloatType>(t1))
54+
llvm_unreachable("Expected float values.");
55+
result = std::get<T>(numericLimitsMaximum(t));
56+
break;
57+
case vector::CombiningKind::MAXSI:
58+
case vector::CombiningKind::MAXUI:
59+
if (not t1.isIntOrIndex())
60+
llvm_unreachable("Expected int or index values.");
61+
result = std::get<T>(numeric_limits_minimum(t));
62+
break;
63+
case vector::CombiningKind::MINSI:
64+
case vector::CombiningKind::MINUI:
65+
if (not t1.isIntOrIndex())
66+
llvm_unreachable("Expected int or index values.");
67+
result = std::get<T>(numericLimitsMaximum(t));
68+
break;
69+
case vector::CombiningKind::MUL:
70+
if (t1.isIntOrIndex())
71+
result = 1;
72+
else if (isa<FloatType>(t1))
73+
result = 1.f;
74+
else
75+
llvm_unreachable("invalid value types for MUL reduction");
76+
break;
77+
default:
78+
llvm_unreachable("unsupported reduction kind");
79+
};
80+
return result;
81+
}
82+
83+
} // namespace gc
84+
} // namespace mlir
85+
86+
#endif

lib/gc/Analysis/VectorBasedFusionAnalysis.cpp

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88
#include "gc/Analysis/VectorBasedFusionAnalysis.h"
9-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
9+
#include "gc/Dialect/Linalgx/Utils.h"
1010

1111
namespace mlir {
1212
namespace gc {
@@ -22,16 +22,16 @@ namespace gc {
2222
arith::TruncFOp, arith::TruncIOp
2323

2424
#define NOT_NEED_TO_PROCESS_OP \
25-
linalg::GenericOp, linalg::BatchReduceMatmulOp, linalg::MatmulOp, \
26-
linalg::BatchMatmulOp, linalg::BatchMatmulTransposeAOp, \
27-
linalg::BatchMatmulTransposeBOp, linalg::MatmulTransposeAOp, \
28-
linalg::MatmulTransposeBOp, linalg::QuantizedBatchMatmulOp, \
29-
linalg::QuantizedMatmulOp, tensor::CollapseShapeOp, \
30-
tensor::ExpandShapeOp, tensor::ExtractSliceOp, tensor::InsertSliceOp, \
31-
microkernel::BrgemmOp
25+
linalg::BatchReduceMatmulOp, linalg::MatmulOp, linalg::BatchMatmulOp, \
26+
linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \
27+
linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \
28+
linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \
29+
tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::ExtractSliceOp, \
30+
tensor::InsertSliceOp, microkernel::BrgemmOp
3231

3332
// True for ops the fusion analysis must skip: the matmul family, shape/slice
// tensor ops, microkernel brgemm (NOT_NEED_TO_PROCESS_OP), and any generic
// op recognized as a packed matmul.
static inline bool isNotNeedToProcessOp(Operation *op) {
  if (isa<NOT_NEED_TO_PROCESS_OP>(op))
    return true;
  return linalgx::isAnyGenericPackedMatmulOp(op);
}
3636

3737
static inline bool isSpecialOp(Operation *op) {
@@ -72,7 +72,7 @@ void shapeCastSourceAxis(const ArrayRef<int64_t> &a, const ArrayRef<int64_t> &b,
7272
while (dimB < dimA && j < rankB)
7373
dimB *= b[j++];
7474
if (dimA != dimB) {
75-
assert(false && " Invalid shape cast operation.");
75+
llvm::llvm_unreachable_internal(" Invalid shape cast operation.");
7676
break;
7777
}
7878
if (bAxisBegin != j) {
@@ -87,12 +87,13 @@ void shapeCastSourceAxis(const ArrayRef<int64_t> &a, const ArrayRef<int64_t> &b,
8787
if (j < rankB && all_of(b.slice(j), isOne))
8888
j = rankB;
8989
}
90-
91-
assert(i == rankA && j == rankB && "Invalid shapecast operation.");
90+
if (i != rankA or j != rankB)
91+
llvm_unreachable("Invalid shapecast operation.");
9292
}
9393

9494
bool isScalar(Type type) {
95-
assert(type && "Not a valid type");
95+
if (not type)
96+
llvm_unreachable("Not a valid type");
9697
if (auto vecType = dyn_cast<VectorType>(type))
9798
return false;
9899
if (auto tensorType = dyn_cast<TensorType>(type))
@@ -107,8 +108,8 @@ void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output,
107108
// following auto_broadcast semantics
108109
const size_t input_rank = inputShape.size();
109110
const size_t output_rank = outputShape.size();
110-
assert(output_rank >= input_rank &&
111-
"Incorrect input or output shape for broadcast op.");
111+
if (output_rank < input_rank)
112+
llvm_unreachable("Incorrect input or output shape for broadcast op.");
112113
const size_t offset = output_rank - input_rank;
113114
for (size_t i = 0; i < input_rank; ++i) {
114115
if (inputShape[i] == outputShape[i + offset] ||
@@ -390,13 +391,16 @@ mlir::FailureOr<VectorType> getOperationMaxVectorType(Operation *op) {
390391

391392
/// select nearest even step
392393
int getNearestVectorStep(const int step) {
393-
assert(step > 0);
394+
if (step <= 0)
395+
llvm_unreachable("Wrong step.");
396+
394397
int nbits = 0, n = step;
395398
while (n) {
396399
n = n >> 1;
397400
nbits++;
398401
}
399-
assert(nbits <= 6 || (nbits == 7 && step == 64));
402+
if (nbits > 6 and !(nbits == 7 && step == 64))
403+
llvm_unreachable("wrong nbits appear");
400404
return (1 << (nbits - 1)) == step ? step : (1 << nbits);
401405
}
402406

@@ -488,7 +492,7 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) {
488492
// down into a loop.
489493
mlir::FailureOr<VectorType> baseType = getOperationVectorType(op);
490494
if (failed(baseType)) {
491-
assert(0 && "Failed to get vector type for operation");
495+
llvm_unreachable("Failed to get vector type for operation");
492496
return VectorType();
493497
}
494498
auto vectorizedType = baseType.value();
@@ -518,7 +522,7 @@ int TypeHelper::generateValidSteps(int steps, VectorType type) {
518522
return favx2bits / typebits;
519523

520524
// invalid hardware
521-
assert(false && "Invalid hardware.");
525+
llvm_unreachable("Invalid hardware.");
522526
return -1;
523527
}
524528

@@ -590,7 +594,8 @@ void GroupOperationFusion::updateGroupBigestVectorType(VectorType vectorType) {
590594
}
591595

592596
void GroupOperationFusion::addOperationToGroup(Operation *op) {
593-
assert(op);
597+
if (not op)
598+
llvm_unreachable("Op can't be NULL.");
594599
VectorType vectorType = getOperationMaxVectorType(op).value();
595600
if (isNeedNewGroup(op))
596601
opGroups.emplace_back(std::queue<Operation *>());

0 commit comments

Comments
 (0)