Skip to content

Commit 81a1ff8

Browse files
christopherlmunoz and jorickert
authored and committed
AMD extension: Bump to e240261
Upgrading llvm and stablehlo hash (onnx#3053)

* Upgrading llvm and stablehlo hash; fixing MLIR tests. Signed-off-by: Christopher Munoz <[email protected]>
* Fixing vector shapecast bug introduced by upgraded llvm. Signed-off-by: Christopher Munoz <[email protected]>

---------

Signed-off-by: Christopher Munoz <[email protected]>
1 parent 6145f33 commit 81a1ff8

File tree

28 files changed

+242
-160
lines changed

28 files changed

+242
-160
lines changed

docs/BuildOnLinuxOSX.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
1515
``` bash
1616
git clone -n https://github.com/llvm/llvm-project.git
1717
# Check out a specific branch that is known to work with ONNX-MLIR.
18-
cd llvm-project && git checkout e8be3bea2ce0ec51b614cd7eb7d5d3a1e56d9524 && cd ..
18+
cd llvm-project && git checkout bb99503626b6efd2bd87a216ff279181cc6ec48f && cd ..
1919
```
2020

2121
[same-as-file]: <> (utils/build-mlir.sh)

docs/BuildOnWindows.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
5252
```shell
5353
git clone -n https://github.com/llvm/llvm-project.git
5454
# Check out a specific branch that is known to work with ONNX-MLIR.
55-
cd llvm-project && git checkout e8be3bea2ce0ec51b614cd7eb7d5d3a1e56d9524 && cd ..
55+
cd llvm-project && git checkout bb99503626b6efd2bd87a216ff279181cc6ec48f && cd ..
5656
```
5757

5858
[same-as-file]: <> (utils/build-mlir.cmd)

src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) {
3535
auto int16Ty = IntegerType::get(context, 16);
3636
auto int32Ty = IntegerType::get(context, 32);
3737
auto int64Ty = IntegerType::get(context, 64);
38-
auto float32Ty = FloatType::getF32(context);
38+
auto float32Ty = Float32Type::get(context);
3939

4040
// Declare API type as an enum value, its string name and an LLVM Type
4141
// specifying its signature.
@@ -570,7 +570,7 @@ Type getZTensorStructTy(MLIRContext *context) {
570570
Type llvmI64Ty = IntegerType::get(context, 64);
571571
Type llvmI1Ty = IntegerType::get(context, 1);
572572
Type llvmI8Ty = IntegerType::get(context, 8);
573-
Type llvmF32Ty = FloatType::getF32(context);
573+
Type llvmF32Ty = Float32Type::get(context);
574574
Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3);
575575
Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20);
576576
Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty);
@@ -662,7 +662,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
662662
scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float");
663663
create.llvm.store(recScale, recScalePtr);
664664
} else {
665-
Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
665+
Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
666666
create.llvm.store(zero, recScalePtr);
667667
}
668668

@@ -675,7 +675,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
675675
offsetTy.isF32() && "Wrong type for zTensor's offset. Must be float");
676676
create.llvm.store(offset, offsetPtr);
677677
} else {
678-
Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
678+
Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
679679
create.llvm.store(zero, offsetPtr);
680680
}
681681

src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -193,12 +193,23 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns,
193193
// They run it in two steps, and add additional lowerings.
194194

195195
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
196+
vector::populateVectorBitCastLoweringPatterns(patterns);
196197
vector::populateVectorBroadcastLoweringPatterns(patterns);
197198
vector::populateVectorContractLoweringPatterns(
198199
patterns, vector::VectorTransformsOptions());
200+
vector::populateVectorMaskOpLoweringPatterns(patterns);
201+
vector::populateVectorShapeCastLoweringPatterns(patterns);
202+
vector::populateVectorInterleaveLoweringPatterns(patterns);
199203
vector::populateVectorTransposeLoweringPatterns(
200204
patterns, vector::VectorTransformsOptions());
201-
vector::populateVectorShapeCastLoweringPatterns(patterns);
205+
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
206+
vector::populateVectorTransferLoweringPatterns(
207+
patterns, /*maxTransferRank=*/1);
208+
vector::populateVectorMaskMaterializationPatterns(
209+
patterns, /*force32BitVectorIndices*/ false);
210+
vector::populateVectorInsertExtractStridedSliceTransforms(patterns);
211+
vector::populateVectorStepLoweringPatterns(patterns);
212+
vector::populateVectorRankReducingFMAPattern(patterns);
202213

203214
populateAffineToStdConversionPatterns(patterns);
204215
populateSCFToControlFlowConversionPatterns(patterns);

src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,10 +80,10 @@ class KrnlRandomNormalOpLowering : public ConversionPattern {
8080
// or
8181
// (memref<3x4x5xf64>, index, f64, f64, f64)
8282
Type llvmVoidTy = LLVM::LLVMVoidType::get(context);
83-
Type llvmOptionsTy = FloatType::getF32(context);
83+
Type llvmOptionsTy = Float32Type::get(context);
8484
Type llvmOutputTy = getPointerType(context, llvmOptionsTy);
8585
if (inType.isF64()) {
86-
llvmOptionsTy = FloatType::getF64(context);
86+
llvmOptionsTy = Float64Type::get(context);
8787
llvmOutputTy = getPointerType(context, llvmOptionsTy);
8888
}
8989
Type llvmI64Ty = IntegerType::get(context, 64);

src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -172,19 +172,19 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
172172
Type outType = op->getResultTypes().front();
173173
Type llvmInType, llvmOutType;
174174
if (inType.isF16())
175-
llvmInType = FloatType::getF16(context);
175+
llvmInType = Float16Type::get(context);
176176
else if (inType.isF32())
177-
llvmInType = FloatType::getF32(context);
177+
llvmInType = Float32Type::get(context);
178178
else if (inType.isF64())
179-
llvmInType = FloatType::getF64(context);
179+
llvmInType = Float64Type::get(context);
180180
else if (inType.isBF16())
181-
llvmInType = FloatType::getBF16(context);
181+
llvmInType = BFloat16Type::get(context);
182182
if (outType.isInteger(1))
183183
llvmOutType = IntegerType::get(context, 1);
184184
else if (outType.isF32())
185-
llvmOutType = FloatType::getF32(context);
185+
llvmOutType = Float32Type::get(context);
186186
else if (outType.isF64())
187-
llvmOutType = FloatType::getF64(context);
187+
llvmOutType = Float64Type::get(context);
188188

189189
// Insert and/or get reference to elementary math function declaration.
190190
assert(

src/Conversion/ONNXToKrnl/Math/LRN.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ struct ONNXLRNOpLowering : public OpConversionPattern<ONNXLRNOp> {
5252
float alphaLit = adaptor.getAlpha().convertToFloat();
5353
float betaLit = adaptor.getBeta().convertToFloat();
5454
int sizeLit = adaptor.getSize();
55-
auto f32Type = FloatType::getF32(rewriter.getContext());
55+
auto f32Type = Float32Type::get(rewriter.getContext());
5656
Value biasValue = create.math.constant(f32Type, biasLit);
5757
Value alphaDivSizeValue =
5858
create.math.constant(f32Type, alphaLit / static_cast<float>(sizeLit));

src/Conversion/ONNXToTOSA/Math/Conv2D.cpp

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
3434
const llvm::ArrayRef<int64_t> weightShape, Value &newInput,
3535
Value &newWeight, Value &bias, const int64_t groups,
3636
DenseI64ArrayAttr &pads, DenseI64ArrayAttr &strides,
37-
DenseI64ArrayAttr &dilations) {
37+
DenseI64ArrayAttr &dilations, TypeAttr &accType) {
3838
// Set up constants outside of loop
3939
const int64_t sizeOfSliceInput = weightShape[1];
4040
const int64_t sizeOfSliceKernel = weightShape[0] / groups;
@@ -65,7 +65,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
6565
mlir::cast<ShapedType>(resultType).getElementType());
6666
Value tempConv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(rewriter,
6767
op->getLoc(), newConvOutputType, newSliceInput, newSliceWeight,
68-
newSliceBias, pads, strides, dilations);
68+
newSliceBias, pads, strides, dilations, accType);
6969
// Add value to vector
7070
sliceValues.push_back(tempConv2D);
7171
}
@@ -156,6 +156,10 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {
156156

157157
DenseI64ArrayAttr newPads = rewriter.getDenseI64ArrayAttr(reorderedPads);
158158

159+
Type convType =
160+
(resultType.isF16()) ? rewriter.getF16Type() : rewriter.getF32Type();
161+
TypeAttr accType = mlir::TypeAttr::get(convType);
162+
159163
// Handle group parameter by creating multiple convs
160164
const int64_t group = adaptor.getGroup();
161165
Value conv2D = NULL;
@@ -166,10 +170,10 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {
166170

167171
conv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(rewriter,
168172
convOp->getLoc(), newConvOutputType, newInput, newWeight, bias,
169-
newPads, strides, dilations);
173+
newPads, strides, dilations, accType);
170174
} else {
171175
auto inputChannels = inputType.getDimSize(1);
172-
auto outputChannels = resultType.cast<ShapedType>().getDimSize(1);
176+
auto outputChannels = cast<ShapedType>(resultType).getDimSize(1);
173177
if (group == inputChannels && (outputChannels % inputChannels == 0)) {
174178
// If the group == inputChannels and
175179
// outputChannels == inputChannels * integerNumber,
@@ -185,19 +189,19 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {
185189

186190
Type newConvOutputType = RankedTensorType::get(
187191
llvm::SmallVector<int64_t, 4>(4, ShapedType::kDynamic),
188-
resultType.cast<ShapedType>().getElementType());
192+
cast<ShapedType>(resultType).getElementType());
189193

190194
conv2D = tosa::CreateOpAndInfer<mlir::tosa::DepthwiseConv2DOp>(rewriter,
191195
convOp->getLoc(), newConvOutputType, newInput, newWeight, bias,
192-
newPads, strides, dilations);
196+
newPads, strides, dilations, accType);
193197
} else if (group <= groupedConvThreshold) {
194198
// Decompose group convolution into a concatenation of tosa.conv2d ops
195199
// can be costly, so only allow it when the number of groups is less
196200
// than configurable threshold.
197201

198202
conv2D = createConvInGroups(rewriter, convOp, tosaBuilder, resultType,
199203
weightShape, newInput, newWeight, bias, group, newPads, strides,
200-
dilations);
204+
dilations, accType);
201205
} else {
202206
return rewriter.notifyMatchFailure(
203207
op, "this type of grouped Conv is not supported");

src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ mlir::Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,
114114

115115
// Create a new pad vec in the right format
116116
// ONNX : [b1, b2, b3, b4, e1, e2, e3, e4]
117-
// TOSA :[[b1, e1], [b2, e2], [b3, e3], [b4, e4]]
117+
// TOSA :[b1, e1, b2, e2, b3, e3, b4, e4]
118118

119119
// Adds any initial or last vals, not included in onnxPads.
120120
llvm::SmallVector<int64_t, 8> tosaPads{initialVals};
@@ -125,11 +125,9 @@ mlir::Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,
125125
tosaPads.push_back(onnxPads[i + dimSize]);
126126
}
127127
tosaPads.insert(tosaPads.end(), lastVals.begin(), lastVals.end());
128-
129-
// TOSA format groups dimensions by 2.
130-
const unsigned int numberOfDims = tosaPads.size() / 2;
131128
TosaBuilder tosaBuilder(rewriter, loc);
132-
return tosaBuilder.getConst(tosaPads, {numberOfDims, 2});
129+
return tosaBuilder.getConst(
130+
tosaPads, {static_cast<int64_t>(tosaPads.size())});
133131
}
134132

135133
mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc,

src/Conversion/ONNXToTOSA/Tensor/Expand.cpp

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
#include <mlir/IR/BuiltinTypeInterfaces.h>
2626
#include <mlir/IR/BuiltinTypes.h>
2727
#include <mlir/Transforms/DialectConversion.h>
28+
#include <mlir/Dialect/Tosa/Utils/ConversionUtils.h>
2829

2930
#include <llvm/ADT/SmallVector.h>
3031
#include <llvm/Support/Casting.h>
@@ -86,8 +87,8 @@ class ONNXExpandLoweringToTOSA : public OpConversionPattern<ONNXExpandOp> {
8687
newInput = tosaBuilder.reshape(adaptor.getInput(), newShape);
8788
}
8889

89-
auto denseShape =
90-
getMultiplies(op, cast<RankedTensorType>(newInput.getType()).getShape(),
90+
const auto multiplies =
91+
getMultiplies(cast<RankedTensorType>(newInput.getType()).getShape(),
9192
outputType.getShape());
9293
auto resultElementType = cast<RankedTensorType>(inputType).getElementType();
9394
auto newResultElementType =
@@ -101,8 +102,9 @@ class ONNXExpandLoweringToTOSA : public OpConversionPattern<ONNXExpandOp> {
101102
llvm::SmallVector<int64_t>(
102103
outputType.getShape().size(), ShapedType::kDynamic),
103104
newResultElementType);
104-
onnx_mlir::tosa::CreateReplaceOpAndInfer<mlir::tosa::TileOp>(
105-
rewriter, op, newTileOutputType, newInput, denseShape);
105+
onnx_mlir::tosa::CreateReplaceOpAndInfer<mlir::tosa::TileOp>(rewriter, op,
106+
newTileOutputType, newInput,
107+
mlir::tosa::getTosaConstShape(rewriter, op->getLoc(), multiplies));
106108
return success();
107109
}
108110

@@ -148,7 +150,7 @@ class ONNXExpandLoweringToTOSA : public OpConversionPattern<ONNXExpandOp> {
148150
return result;
149151
}
150152

151-
static DenseI64ArrayAttr getMultiplies(ONNXExpandOp &op,
153+
static llvm::SmallVector<int64_t> getMultiplies(
152154
const llvm::ArrayRef<int64_t> &inputShape,
153155
const llvm::ArrayRef<int64_t> &outputShape) {
154156
llvm::SmallVector<int64_t> multipliesArray;
@@ -159,7 +161,7 @@ class ONNXExpandLoweringToTOSA : public OpConversionPattern<ONNXExpandOp> {
159161
multipliesArray.push_back(outputShape[i] / inputShape[i]);
160162
}
161163
}
162-
return DenseI64ArrayAttr::get(op.getContext(), multipliesArray);
164+
return multipliesArray;
163165
}
164166
};
165167

0 commit comments

Comments (0)