Skip to content

Commit 289459b

Browse files
committed
1 parent 15e6427 commit 289459b

File tree

5 files changed

+22
-15
lines changed

5 files changed

+22
-15
lines changed

external/mlir-hal/lib/Dialect/MHAL/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010

1111
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1212
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
1314
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
1415
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
1516
#include "mlir/Dialect/Func/IR/FuncOps.h"
1617
#include "mlir/Dialect/MHAL/IR/MHAL.h"
1718
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19+
#include "mlir/IR/BuiltinTypes.h"
1820
#include "mlir/IR/Dialect.h"
1921
#include "mlir/IR/Operation.h"
2022
#include "mlir/IR/PatternMatch.h"
@@ -124,11 +126,13 @@ struct LaunchOpInterface
124126
Type returnType = returnVal.getType();
125127
if (isa<TensorType>(returnType)) {
126128
assert(returnType == callResultTypes[funcResultIdx++]);
127-
FailureOr<BaseMemRefType> memrefType =
129+
FailureOr<BufferLikeType> bufferType =
128130
bufferization::getBufferType(returnVal, options, state);
129-
if (failed(memrefType))
131+
if (failed(bufferType))
130132
return failure();
131-
resultTypes.push_back(*memrefType);
133+
assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
134+
BaseMemRefType memrefType = cast<BaseMemRefType>(*bufferType);
135+
resultTypes.push_back(memrefType);
132136
} else {
133137
// Non-tensor values are returned.
134138
resultTypes.push_back(returnType);

mlir/lib/Conversion/EmulateFp8ExtTrunc/EmulateFp8ExtTrunc.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -311,14 +311,14 @@ static FlatSymbolRefAttr makeFp8TruncFunction(Location loc, FloatType outType,
311311
Value cmp = b.create<CmpIOp>(CmpIPredicate::eq, and4, infNanConst);
312312

313313
Block *notInfNan = func.addBlock();
314-
Value outNan = b.create<ConstantFloatOp>(APFloat::getQNaN(outSem), outType);
314+
Value outNan = b.create<ConstantFloatOp>(outType, APFloat::getQNaN(outSem));
315315
b.create<cf::CondBranchOp>(cmp, ret, ValueRange{outNan}, notInfNan,
316316
ValueRange{});
317317
b.setInsertionPointToStart(notInfNan);
318318

319319
// A deviation from the MIGraphX: denormals are zero here
320320
Value cmp5 = b.create<CmpIOp>(CmpIPredicate::eq, and2, i32Const(0));
321-
Value outZero = b.create<ConstantFloatOp>(APFloat::getZero(outSem), outType);
321+
Value outZero = b.create<ConstantFloatOp>(outType, APFloat::getZero(outSem));
322322
Block *notZero = func.addBlock();
323323
b.create<cf::CondBranchOp>(cmp5, ret, ValueRange{outZero}, notZero,
324324
ValueRange{});
@@ -366,7 +366,7 @@ static FlatSymbolRefAttr makeFp8TruncFunction(Location loc, FloatType outType,
366366
Value cmp57 = b.create<CmpIOp>(CmpIPredicate::ne, sub43, i32Const(0));
367367
Value and58 = b.create<AndIOp>(add56, i32Const(1 << 23));
368368
Value tobool59Not = b.create<CmpIOp>(CmpIPredicate::eq, and58, i32Const(0));
369-
Value trueConst = b.create<ConstantIntOp>(true, 1);
369+
Value trueConst = b.create<ConstantIntOp>(true, /*width=*/1);
370370
Value brCond133 = b.create<SelectOp>(cmp57, trueConst, tobool59Not);
371371

372372
Block *ifElse61 = func.addBlock();
@@ -392,7 +392,7 @@ static FlatSymbolRefAttr makeFp8TruncFunction(Location loc, FloatType outType,
392392
b.setInsertionPointToStart(ifThen70);
393393
Value ir5 = b.create<TruncIOp>(i8, ir1);
394394
Value conv =
395-
b.create<OrIOp>(ir5, b.create<ConstantIntOp>(127, b.getI8Type()));
395+
b.create<OrIOp>(ir5, b.create<ConstantIntOp>(b.getI8Type(), 127));
396396
Value convOut = b.create<BitcastOp>(outType, conv);
397397
b.create<cf::BranchOp>(ret, convOut);
398398

@@ -402,7 +402,7 @@ static FlatSymbolRefAttr makeFp8TruncFunction(Location loc, FloatType outType,
402402
Value cmp72 = b.create<CmpIOp>(CmpIPredicate::eq, f8Exponent0, i32Const(0));
403403
Value cmp74 = b.create<CmpIOp>(CmpIPredicate::ult, mantissa1,
404404
i32Const(1 << (16 + eBits)));
405-
Value falseConst = b.create<ConstantIntOp>(false, 1);
405+
Value falseConst = b.create<ConstantIntOp>(false, /*width=*/1);
406406
Value brCond = b.create<SelectOp>(cmp72, cmp74, falseConst);
407407
b.create<cf::CondBranchOp>(brCond, ret, ValueRange{outZero}, ifEnd76,
408408
ValueRange{f8Exponent0, mantissa1});
@@ -735,7 +735,7 @@ void Fp8TruncToCallPattern::rewrite(TruncFOp op, OpAdaptor adaptor,
735735
Value rets = rewriter.createOrFold<vector::SplatOp>(
736736
loc,
737737
rewriter.createOrFold<ConstantFloatOp>(
738-
loc, APFloat::getZero(outElemType.getFloatSemantics()), outElemType),
738+
loc, outElemType, APFloat::getZero(outElemType.getFloatSemantics())),
739739
retVecType);
740740
SmallVector<int64_t> strides = computeStrides(inVecType.getShape());
741741
for (int64_t i = 0, e = inVecType.getNumElements(); i < e; ++i) {

mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ static Value getConstIntOrIndexValue(OpBuilder &b, Location loc, int64_t value,
916916
if (isa<IndexType>(type)) {
917917
return b.create<ConstantIndexOp>(loc, value);
918918
}
919-
return b.create<ConstantIntOp>(loc, value, type);
919+
return b.create<ConstantIntOp>(loc, type, value);
920920
}
921921

922922
// Manually flatten a set of coordinates into a single address

mlir/lib/Dialect/Rock/utility/builderUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ Value getAsTensor(OpBuilder &builder, Location loc, mlir::Value value,
209209
bool isWritable) {
210210
constexpr bool isRestrict{true};
211211
Value origTensor = builder.create<bufferization::ToTensorOp>(
212-
loc, value, isRestrict, isWritable);
212+
loc, value.getType(), value, isRestrict, isWritable);
213213
return origTensor;
214214
}
215215

mlir/tools/rocmlir-gen/rocmlir-gen.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3280,7 +3280,8 @@ createCpuConvElementwiseGemmKernelWithMlir(ModuleOp module,
32803280
bool isWritable = false) {
32813281
constexpr bool isRestrict{true};
32823282
Value flatTensor = builder.create<bufferization::ToTensorOp>(
3283-
loc, block->getArgument(blockArgIndex), isRestrict, isWritable);
3283+
loc, block->getArgument(blockArgIndex).getType(),
3284+
block->getArgument(blockArgIndex), isRestrict, isWritable);
32843285
ArrayRef<int64_t> origShape =
32853286
cast<ShapedType>(argTypes[blockArgIndex]).getShape();
32863287

@@ -3459,7 +3460,8 @@ createCpuGemmElementwiseGemmKernelWithMlir(ModuleOp module,
34593460
bool isWritable = false) {
34603461
constexpr bool isRestrict{true};
34613462
Value flatTensor = builder.create<bufferization::ToTensorOp>(
3462-
loc, block->getArgument(blockArgIndex), isRestrict, isWritable);
3463+
loc, block->getArgument(blockArgIndex).getType(),
3464+
block->getArgument(blockArgIndex), isRestrict, isWritable);
34633465
ArrayRef<int64_t> origShape =
34643466
cast<ShapedType>(argTypes[blockArgIndex]).getShape();
34653467

@@ -3574,7 +3576,8 @@ static func::FuncOp createCpuAttentionKernelWithMlir(ModuleOp module,
35743576
bool isWritable = false) {
35753577
constexpr bool isRestrict{true};
35763578
Value flatTensor = builder.create<bufferization::ToTensorOp>(
3577-
loc, block->getArgument(blockArgIndex), isRestrict, isWritable);
3579+
loc, block->getArgument(blockArgIndex).getType(),
3580+
block->getArgument(blockArgIndex), isRestrict, isWritable);
35783581
ArrayRef<int64_t> origShape =
35793582
cast<ShapedType>(argTypes[blockArgIndex]).getShape();
35803583

@@ -3935,7 +3938,7 @@ static func::FuncOp createVerifierFunc(ModuleOp module, const KernelIF &kernel,
39353938
char printDebug = static_cast<char>(printVerifyResults.getValue());
39363939

39373940
auto printDebugVal =
3938-
b.create<arith::ConstantIntOp>(loc, printDebug, charType);
3941+
b.create<arith::ConstantIntOp>(loc, charType, printDebug);
39393942

39403943
// obtain function name of the verifier wrapper
39413944
std::string verifyFuncName = "mcpuVerify";

0 commit comments

Comments
 (0)