Skip to content

Commit 113bbb1

Browse files
Removing unnecessary tests and changes, and fixing casts from vector to scalar
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 814d53b commit 113bbb1

File tree

3 files changed

+170
-171
lines changed

3 files changed

+170
-171
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,9 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
243243
Type operandETy = getElementTypeOrSelf(operandTy);
244244
Type resultETy = getElementTypeOrSelf(resultTy);
245245

246-
if (!operandETy.isBF16() || !resultETy.isF32())
246+
if (!operandETy.isBF16() || !resultETy.isF32()) {
247247
return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
248+
}
248249

249250
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
250251
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -272,8 +273,9 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
272273
Type operandETy = getElementTypeOrSelf(operandTy);
273274
Type resultETy = getElementTypeOrSelf(resultTy);
274275

275-
if (!operandETy.isF32() || !resultETy.isBF16())
276+
if (!operandETy.isF32() || !resultETy.isBF16()) {
276277
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
278+
}
277279

278280
if (op.getRoundingmodeAttr()) {
279281
return rewriter.notifyMatchFailure(
@@ -422,7 +424,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
422424
Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32);
423425
Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32);
424426
if (!isa<Float32Type>(resultETy))
425-
result = b.create<arith::TruncFOp>(resultETy, operand);
427+
result = b.create<arith::TruncFOp>(resultTy, result);
426428

427429
rewriter.replaceOp(op, result);
428430
return success();
@@ -440,8 +442,9 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
440442
Type operandETy = getElementTypeOrSelf(operandTy);
441443
Type resultETy = getElementTypeOrSelf(resultTy);
442444

443-
if (!llvm::isa<Float8E8M0FNUType>(operandETy))
445+
if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
444446
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
447+
}
445448

446449
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
447450
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
@@ -512,16 +515,16 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
512515
Type operandETy = getElementTypeOrSelf(operandTy);
513516
Type resultETy = getElementTypeOrSelf(resultTy);
514517

515-
if (!isa<Float32Type>(operandETy))
516-
operand = b.create<arith::ExtFOp>(b.getF32Type(), operand);
517-
if (!isa<Float4E2M1FNType>(resultETy))
518-
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
519-
520518
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
521519
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
522520
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
523521
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
524522

523+
if (!isa<Float32Type>(operandETy))
524+
operand = b.create<arith::ExtFOp>(f32Ty, operand);
525+
if (!isa<Float4E2M1FNType>(resultETy))
526+
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
527+
525528
Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
526529
Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
527530
Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
@@ -611,12 +614,14 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
611614
Type operandETy = getElementTypeOrSelf(operandTy);
612615
Type resultTy = op.getType();
613616
Type resultETy = getElementTypeOrSelf(resultTy);
614-
if (!llvm::isa<Float8E8M0FNUType>(resultETy))
617+
if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
615618
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
619+
}
616620

617-
if (op.getRoundingmodeAttr())
621+
if (op.getRoundingmodeAttr()) {
618622
return rewriter.notifyMatchFailure(
619623
op, "only applicable to default rounding mode.");
624+
}
620625

621626
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
622627
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());

mlir/test/Dialect/Arith/expand-ops-scale.mlir

Lines changed: 0 additions & 159 deletions
This file was deleted.

0 commit comments

Comments
 (0)