Skip to content

Commit fae7b3a

Browse files
committed
update to use builtin types
1 parent 12ceea5 commit fae7b3a

File tree

4 files changed

+29
-19
lines changed

4 files changed

+29
-19
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,17 +1654,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
16541654
}];
16551655
}
16561656

1657-
def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
1658-
1659-
def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
1660-
[ConvertFP4E2M1]> {
1661-
let cppNamespace = "::mlir::NVVM";
1662-
}
1663-
def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type,
1664-
"convert_fp4_type"> {
1665-
let assemblyFormat = "`<` $value `>`";
1666-
}
1667-
16681657
def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
16691658
let summary = "Convert a pair of float inputs to f4x2";
16701659
let description = [{
@@ -1680,16 +1669,18 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
16801669
}];
16811670

16821671
let results = (outs I8:$dst);
1683-
let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
1684-
DefaultValuedAttr<BoolAttr, "false">:$relu);
1685-
let assemblyFormat = "$type $a `,` $b attr-dict";
1672+
let arguments = (ins F32:$a, F32:$b,
1673+
DefaultValuedAttr<BoolAttr, "false">:$relu,
1674+
TypeAttr:$dstTy);
1675+
let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
1676+
let hasVerifier = 1;
16861677

16871678
let extraClassDeclaration = [{
1688-
static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
1679+
static llvm::Intrinsic::ID getIntrinsicID(bool hasRelu);
16891680
}];
16901681

16911682
string llvmBuilder = [{
1692-
auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
1683+
auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($relu);
16931684
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
16941685
$dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
16951686
}];

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
271271
return success();
272272
}
273273

274+
LogicalResult ConvertF32x2ToF4x2Op::verify() {
275+
mlir::MLIRContext *ctx = getContext();
276+
277+
if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
278+
return emitOpError("Only ")
279+
<< mlir::Float4E2M1FNType::get(ctx)
280+
<< " type is supported for conversions from f32x2 to f4x2.";
281+
282+
return success();
283+
}
284+
274285
LogicalResult BulkStoreOp::verify() {
275286
if (getInitVal() != 0)
276287
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1977,7 +1988,7 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
19771988
}
19781989

19791990
llvm::Intrinsic::ID
1980-
ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
1991+
ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
19811992
return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
19821993
: llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
19831994
}

mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
55
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
66
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
7-
%res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
7+
%res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
88
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
99
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
10-
%res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
10+
%res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
1111
llvm.return
1212
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,14 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
246246

247247
// -----
248248

249+
llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
250+
// expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
251+
%res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
252+
llvm.return
253+
}
254+
255+
// -----
256+
249257
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
250258
// expected-error @below {{cache eviction priority supported only for cache level L2}}
251259
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>

0 commit comments

Comments
 (0)