@@ -184,11 +184,11 @@ Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
184184
185185// AMD: Picked from torch-mlir 12250739bfe85b702f9503cad45c2e535ea8eb18
186186// Get accumulator type for TOSA convolution ops
187- LogicalResult mlir::tosa ::getConvOpsAccType (PatternRewriter &rewriter,
188- RankedTensorType inputTy,
189- RankedTensorType weightTy,
190- RankedTensorType outputTy,
191- TypeAttr &accType) {
187+ LogicalResult mlir::tosa::getConvOpsAccType (PatternRewriter &rewriter,
188+ RankedTensorType inputTy,
189+ RankedTensorType weightTy,
190+ RankedTensorType outputTy,
191+ TypeAttr &accType) {
192192 auto inputElemTy = inputTy.getElementType ();
193193 auto weightElemTy = weightTy.getElementType ();
194194 auto outputElemTy = outputTy.getElementType ();
@@ -218,8 +218,8 @@ LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter,
218218 } else if (inputElemTy.isInteger (16 ) && weightElemTy.isInteger (8 ) &&
219219 outputElemTy.isInteger (48 )) {
220220 accType = mlir::TypeAttr::get (rewriter.getIntegerType (48 ));
221- } else if ((isa<Float8E4M3FNType >(inputElemTy) &&
222- isa<Float8E4M3FNType >(weightElemTy) && outputElemTy.isF16 ()) ||
221+ } else if ((isa<Float8E4M3Type >(inputElemTy) &&
222+ isa<Float8E4M3Type >(weightElemTy) && outputElemTy.isF16 ()) ||
223223 (isa<Float8E5M2Type>(inputElemTy) &&
224224 isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16 ())) {
225225 accType = mlir::TypeAttr::get (rewriter.getF16Type ());
0 commit comments