Skip to content

Commit b3c4977

Browse files
add tests and fix various issues revealed by tests
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 2c20ce6 commit b3c4977

File tree

6 files changed

+333
-176
lines changed

6 files changed

+333
-176
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
5959
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
6060
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
6161

62+
/// Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
63+
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns);
64+
6265
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
6366
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
6467

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
1919
"Enable the BF16 expansion patterns">,
2020
Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
2121
"Enable the F8E8M0 expansion patterns">,
22+
Option<"includeF4E2M1", "include-f4e2m1", "bool", /*default=*/"false",
23+
"Enable the F4E2M1 expansion patterns">,
2224
];
2325
}
2426

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,8 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
345345
Type operandETy = getElementTypeOrSelf(operandTy);
346346
Type resultETy = getElementTypeOrSelf(resultTy);
347347

348-
if (!llvm::isa<Float4E2M1FNType>(operandETy) ||
349-
!llvm::isa<Float32Type>(resultETy)) {
350-
return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN to F32");
348+
if (!isa<Float4E2M1FNType>(operandETy)) {
349+
return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
351350
}
352351

353352
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -357,8 +356,9 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
357356
Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
358357

359358
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
360-
Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
361359
Value c0x00000014 = createConst(op->getLoc(), i32Ty, 22, rewriter);
360+
Value c0x00000015 = createConst(op->getLoc(), i32Ty, 23, rewriter);
361+
Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
362362
Value cZero =
363363
createFloatConst(op->getLoc(), f32Ty, APFloat(0.0f), rewriter);
364364
Value cHalf =
@@ -370,29 +370,33 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
370370

371371
Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
372372
Value f32Bits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
373-
f32Bits = b.create<arith::ShRUIOp>(f32Bits, c0x0000001c);
373+
f32Bits = b.create<arith::ShLIOp>(f32Bits, c0x0000001c);
374374

375375
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 126, rewriter);
376376
Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
377377
f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
378378
Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
379379
f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
380-
f32ExpBits = b.create<arith::ShLIOp>(f32ExpBits, c0x00000014);
381-
f32Bits = b.create<arith::AddIOp>(f32Bits, f32ExpBits);
380+
Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, c0x00000015);
381+
f32Bits = b.create<arith::AddIOp>(f32Bits, f32Exp);
382382

383383
Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
384384
Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
385+
f32ManBit = b.create<arith::ShLIOp>(f32ManBit, c0x00000014);
385386
f32Bits = b.create<arith::AddIOp>(f32Bits, f32ManBit);
386387

387-
// Special consideration for subnormal exp (exp == 0).
388+
// Special consideration for subnormal exponent (exp == 00).
388389
Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
389390
f32ExpBits, biasAdjustment);
390391
Value isManSet =
391392
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
392393
Value subnormalVal = b.create<arith::SelectOp>(isManSet, cHalf, cZero);
393-
f32Bits = b.create<arith::SelectOp>(isSubnormal, subnormalVal, f32Bits);
394394

395395
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
396+
result = b.create<arith::SelectOp>(isSubnormal, subnormalVal, result);
397+
if (!isa<Float32Type>(resultETy)) {
398+
result = b.create<arith::TruncFOp>(resultETy, operand);
399+
}
396400
rewriter.replaceOp(op, result);
397401
return success();
398402
}
@@ -481,8 +485,11 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
481485
Type operandETy = getElementTypeOrSelf(operandTy);
482486
Type resultETy = getElementTypeOrSelf(resultTy);
483487

484-
if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
485-
return rewriter.notifyMatchFailure(op, "not a trunc of F32 to F4E2M1FN");
488+
if (!isa<Float32Type>(operandETy)) {
489+
operand = b.create<arith::ExtFOp>(b.getF32Type(), operand);
490+
}
491+
if (!isa<Float4E2M1FNType>(resultETy)) {
492+
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
486493
}
487494

488495
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
@@ -491,20 +498,28 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
491498
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
492499

493500
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
501+
Value c0x3 = createConst(op->getLoc(), i4Ty, 3, rewriter);
494502
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
495503
Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
496504
Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
497505
Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
498506

499-
// Step 1: Clamp to bounds.
507+
// Step 0: Clamp to bounds.
500508
Value cHigherBound =
501509
createFloatConst(op->getLoc(), f32Ty, APFloat(6.0f), rewriter);
502510
Value cLowerBound =
503511
createFloatConst(op->getLoc(), f32Ty, APFloat(-6.0f), rewriter);
504-
Value operandClamped = b.create<arith::MinimumFOp>(cLowerBound, operand);
505-
operandClamped = b.create<arith::MaximumFOp>(cHigherBound, operandClamped);
512+
Value operandClamped = b.create<arith::MinimumFOp>(cHigherBound, operand);
513+
operandClamped = b.create<arith::MaximumFOp>(cLowerBound, operandClamped);
506514
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
507515

516+
// Step 1: Set sign bit.
517+
Value cF32ExpManWidth =
518+
createConst(op->getLoc(), i32Ty, 31, rewriter); // 31 = 8 exponent bits + 23 mantissa bits
519+
Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
520+
Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
521+
Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
522+
508523
// Step 2: Convert exponent by adjusting bias.
509524
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
510525
Value cF4MantissaWidth = c0x1; // 1
@@ -513,16 +528,17 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
513528
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
514529
Value biasAdjustedSignExp =
515530
b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
516-
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
517-
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
531+
Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
532+
f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
533+
f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
518534

519535
// Step 3: Set mantissa to first bit.
520536
Value cF32FirstBitMask =
521537
createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
522538
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
523539
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
524540
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
525-
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
541+
f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
526542

527543
// Step 4: Special consideration for conversion to 0.5.
528544
Value cF32MantissaMask =
@@ -538,7 +554,6 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
538554
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
539555
Value isZeroExp =
540556
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
541-
542557
Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
543558
Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
544559
Value subResult =
@@ -719,16 +734,24 @@ struct ArithExpandOpsPass
719734
if (includeF8E8M0) {
720735
arith::populateExpandF8E8M0Patterns(patterns);
721736
}
737+
if (includeF4E2M1) {
738+
arith::populateExpandF4E2M1Patterns(patterns);
739+
}
722740

723741
target.addDynamicallyLegalOp<arith::ExtFOp>(
724742
[=](arith::ExtFOp op) {
725743
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
726744
Type outETy = getElementTypeOrSelf(op.getType());
727745
bool legalTypes = true;
728-
if (includeBf16)
746+
if (includeBf16) {
729747
legalTypes &= !(inETy.isBF16() && outETy.isF32());
730-
if (includeF8E8M0)
748+
}
749+
if (includeF8E8M0) {
731750
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
751+
}
752+
if (includeF4E2M1) {
753+
legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
754+
}
732755
return legalTypes;
733756
});
734757

@@ -737,10 +760,15 @@ struct ArithExpandOpsPass
737760
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
738761
Type outETy = getElementTypeOrSelf(op.getType());
739762
bool legalTypes = true;
740-
if (includeBf16)
763+
if (includeBf16) {
741764
legalTypes &= !(inETy.isF32() && outETy.isBF16());
742-
if (includeF8E8M0)
765+
}
766+
if (includeF8E8M0) {
743767
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
768+
}
769+
if (includeF4E2M1) {
770+
legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
771+
}
744772
return legalTypes;
745773
});
746774

@@ -765,6 +793,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
765793
patterns.getContext());
766794
}
767795

796+
void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
797+
patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
798+
patterns.getContext());
799+
}
800+
768801
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
769802
patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
770803
patterns.getContext());
Lines changed: 159 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,159 @@
1+
// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN {
4+
%0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN
5+
return %0 : f4E2M1FN
6+
}
7+
8+
// CHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN
9+
// CHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
10+
// CHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32
11+
// CHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN
12+
// CHECK: return %[[RESULT]]
13+
14+
// -----
15+
16+
func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> {
17+
%0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN>
18+
return %0 : vector<4xf6E3M2FN>
19+
}
20+
21+
// CHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN
22+
// CHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
23+
// CHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16>
24+
// CHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN>
25+
// CHECK: return %[[RESULT]] : vector<4xf6E3M2FN>
26+
27+
// -----
28+
29+
func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> {
30+
%0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath<fast> : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN>
31+
return %0 : vector<4xf6E3M2FN>
32+
}
33+
// CHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math
34+
// CHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
35+
// CHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf16>
36+
// CHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath<fast> : vector<4xf16>
37+
// CHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath<fast> : vector<4xf16> to vector<4xf6E3M2FN>
38+
// CHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN>
39+
40+
// -----
41+
42+
func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN {
43+
%0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN
44+
return %0 : f4E2M1FN
45+
}
46+
// CHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales
47+
// CHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
48+
// CHECK: return
49+
50+
// -----
51+
func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> {
52+
%0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN>
53+
return %0 : vector<4xf4E2M1FN>
54+
}
55+
// CHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales
56+
// CHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
57+
// CHECK: return
58+
59+
// -----
60+
61+
func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 {
62+
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32
63+
return %0 : f32
64+
}
65+
66+
// CHECK-LABEL: @scaling_extf_to_f32
67+
// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
68+
// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
69+
// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
70+
// CHECK: return %[[RESULT]]
71+
72+
// -----
73+
74+
func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 {
75+
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32
76+
return %0 : f32
77+
}
78+
79+
// CHECK-LABEL: @scaling_extf_to_f32_using_f16_scales
80+
// CHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU
81+
// CHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32
82+
// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32
83+
// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32
84+
// CHECK: return %[[RESULT]]
85+
86+
// -----
87+
88+
func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 {
89+
// expected-error@+1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}}
90+
%0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32
91+
return %0 : f32
92+
}
93+
94+
// -----
95+
96+
func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
97+
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32>
98+
return %0 : vector<4xf32>
99+
}
100+
101+
// CHECK-LABEL: @scaling_extf_vector_to_f32
102+
// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32>
103+
// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
104+
// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
105+
// CHECK: return %[[RESULT]]
106+
107+
// -----
108+
109+
func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
110+
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16>
111+
return %0 : vector<4xf16>
112+
}
113+
114+
// CHECK-LABEL: @scaling_extf_vector_to_f16
115+
// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16>
116+
// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16>
117+
// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16>
118+
// CHECK: return %[[RESULT]]
119+
120+
// -----
121+
122+
func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
123+
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16>
124+
return %0 : vector<4xbf16>
125+
}
126+
127+
// CHECK-LABEL: @scaling_extf_vector_to_bf16
128+
// CHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16>
129+
// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16>
130+
// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16>
131+
// CHECK: return %[[RESULT]]
132+
133+
// -----
134+
135+
func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
136+
%0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
137+
return %0 : vector<4xf32>
138+
}
139+
140+
// CHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales
141+
// CHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU>
142+
// CHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32>
143+
// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
144+
// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32>
145+
// CHECK: return %[[RESULT]]
146+
147+
// -----
148+
149+
func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> {
150+
%0 = arith.scaling_extf %arg0, %arg1 fastmath<fast> : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32>
151+
return %0 : vector<4xf32>
152+
}
153+
154+
// CHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath
155+
// CHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath<fast> : vector<4xf16> to vector<4xf8E8M0FNU>
156+
// CHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath<fast> : vector<4xf8E8M0FNU> to vector<4xf32>
157+
// CHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath<fast> : vector<4xf4E2M1FN> to vector<4xf32>
158+
// CHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath<fast> : vector<4xf32>
159+
// CHECK: return %[[RESULT]]

0 commit comments

Comments
 (0)