Skip to content

Commit ec5f5e6

Browse files
committed
Update conv_acc helper function with changes in torch-mlir
1 parent 4ed6347 commit ec5f5e6

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,11 @@ Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
184184

185185
// AMD: Picked from torch-mlir 12250739bfe85b702f9503cad45c2e535ea8eb18
186186
// Get accumulator type for TOSA convolution ops
187-
LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter,
188-
RankedTensorType inputTy,
189-
RankedTensorType weightTy,
190-
RankedTensorType outputTy,
191-
TypeAttr &accType) {
187+
LogicalResult mlir::tosa::getConvOpsAccType(PatternRewriter &rewriter,
188+
RankedTensorType inputTy,
189+
RankedTensorType weightTy,
190+
RankedTensorType outputTy,
191+
TypeAttr &accType) {
192192
auto inputElemTy = inputTy.getElementType();
193193
auto weightElemTy = weightTy.getElementType();
194194
auto outputElemTy = outputTy.getElementType();
@@ -218,8 +218,8 @@ LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter,
218218
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
219219
outputElemTy.isInteger(48)) {
220220
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
221-
} else if ((isa<Float8E4M3FNType>(inputElemTy) &&
222-
isa<Float8E4M3FNType>(weightElemTy) && outputElemTy.isF16()) ||
221+
} else if ((isa<Float8E4M3Type>(inputElemTy) &&
222+
isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
223223
(isa<Float8E5M2Type>(inputElemTy) &&
224224
isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
225225
accType = mlir::TypeAttr::get(rewriter.getF16Type());

0 commit comments

Comments
 (0)