-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][AMDGPU] Add int4 intrinsics, mixed-type fp8 to handle gfx12 #128963
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -403,8 +403,11 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, | |
| /// Push an input operand. If it is a float type, nothing to do. If it is | ||
| /// an integer type, then we need to also push its signedness (1 for signed, 0 | ||
| /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32 | ||
| /// vector. We also need to convert bfloat inputs to i16 to account for the lack | ||
| /// of bfloat support in the WMMA intrinsics themselves. | ||
| /// vector (or the 8xi8 vector into a 2xi32 one for gfx12+). | ||
| /// We also need to convert bfloat inputs to i16 to account for the bfloat | ||
| /// intrinsics having been defined before the AMD backend supported bfloat. We | ||
| /// similarly need to pack 8-bit float types into integers as if they were i8 | ||
| /// (which they are for the backend's purposes). | ||
| static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, | ||
| Location loc, | ||
| const TypeConverter *typeConverter, | ||
|
|
@@ -413,12 +416,16 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, | |
| SmallVector<Value, 4> &operands) { | ||
| Type inputType = llvmInput.getType(); | ||
| auto vectorType = dyn_cast<VectorType>(inputType); | ||
| if (!vectorType) { | ||
| operands.push_back(llvmInput); | ||
| return; | ||
| } | ||
| Type elemType = vectorType.getElementType(); | ||
|
|
||
| if (elemType.isBF16()) | ||
| llvmInput = rewriter.create<LLVM::BitcastOp>( | ||
| loc, vectorType.clone(rewriter.getI16Type()), llvmInput); | ||
| if (!elemType.isInteger(8)) { | ||
| if (elemType.getIntOrFloatBitWidth() > 8) { | ||
| operands.push_back(llvmInput); | ||
| return; | ||
| } | ||
|
|
@@ -427,25 +434,33 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, | |
| // for int8. This is because, in LLVM, fp8 type is converted to int8, so the | ||
| // fp8/int8 information is lost during the conversion process. | ||
| auto mlirInputType = cast<VectorType>(mlirInput.getType()); | ||
| bool isInputInt8 = mlirInputType.getElementType().isInteger(8); | ||
| if (isInputInt8) { | ||
| bool isInputInteger = mlirInputType.getElementType().isInteger(); | ||
| if (isInputInteger) { | ||
| // if element type is a signed or unsigned integer, ignore the isUnsigned flag | ||
| bool localIsUnsigned = isUnsigned; | ||
| if (elemType.isUnsignedInteger(8)) { | ||
| if (elemType.isUnsignedInteger()) { | ||
| localIsUnsigned = true; | ||
| } else if (elemType.isSignedInteger(8)) { | ||
| } else if (elemType.isSignedInteger()) { | ||
| localIsUnsigned = false; | ||
| } | ||
| Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); | ||
| operands.push_back(sign); | ||
| } | ||
|
|
||
| int64_t numBytes = vectorType.getNumElements(); | ||
| int64_t numBits = | ||
| vectorType.getNumElements() * elemType.getIntOrFloatBitWidth(); | ||
| Type i32 = rewriter.getI32Type(); | ||
| VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32); | ||
| auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits); | ||
| Type intrinsicInType = numBits <= 32 | ||
| ? (Type)rewriter.getIntegerType(numBits) | ||
| : (Type)VectorType::get(numBits / 32, i32); | ||
| auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType); | ||
| Value result = rewriter.createOrFold<LLVM::BitcastOp>( | ||
|
||
| loc, llvmVectorType32bits, llvmInput); | ||
| loc, llvmIntrinsicInType, llvmInput); | ||
| // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need | ||
| // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments. | ||
| // Add in the zeros here. | ||
| if (numBits < 32) | ||
| result = rewriter.create<LLVM::ZExtOp>(loc, i32, result); | ||
| operands.push_back(result); | ||
| } | ||
|
|
||
|
|
@@ -454,7 +469,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, | |
| /// since the same number of VGPRs is used, we need to decide whether to store the | ||
| /// result in the upper 16 bits of the VGPRs or in the lower part. To store the | ||
| /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will | ||
| /// be stored in the upper part | ||
| /// be stored in the upper part. The subwordOffset must not be set for gfx12, | ||
| /// as the instructions have been changed to return fewer registers instead. | ||
| static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, | ||
| Location loc, | ||
| const TypeConverter *typeConverter, | ||
|
|
@@ -617,8 +633,10 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma, | |
| static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, | ||
| Chipset chipset) { | ||
| auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType()); | ||
| auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType()); | ||
| auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType()); | ||
| auto elemSourceType = sourceVectorType.getElementType(); | ||
| auto elemBSourceType = sourceBVectorType.getElementType(); | ||
| auto elemDestType = destVectorType.getElementType(); | ||
|
|
||
| if (elemSourceType.isF16() && elemDestType.isF32()) | ||
|
|
@@ -631,10 +649,33 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, | |
| return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); | ||
| if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) | ||
| return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); | ||
| if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32()) | ||
| return ROCDL::wmma_f32_16x16x16_fp8::getOperationName(); | ||
| if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32()) | ||
| return ROCDL::wmma_f32_16x16x16_bf8::getOperationName(); | ||
| if (chipset.majorVersion == 11) { | ||
| if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) | ||
| return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); | ||
| } | ||
| if (chipset.majorVersion >= 12) { | ||
| if (isa<Float8E4M3FNType>(elemSourceType) && | ||
| isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) | ||
| return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); | ||
| if (isa<Float8E4M3FNType>(elemSourceType) && | ||
| isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) | ||
| return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName(); | ||
| if (isa<Float8E5M2Type>(elemSourceType) && | ||
| isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) | ||
| return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName(); | ||
| if (isa<Float8E5M2Type>(elemSourceType) && | ||
| isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) | ||
| return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); | ||
| if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) { | ||
| bool isWave64 = destVectorType.getNumElements() == 4; | ||
| // This is the ambiguous case. 8 inputs to the wave64 version means that | ||
| // we want the 16x16x32 version, but for wave32 it means the short form. | ||
| bool has8Inputs = sourceVectorType.getNumElements() == 8; | ||
| if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: if(isWave64 == has8Inputs)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I figured I'd use a somewhat verbose chunk of logic to make it clear what the cases are |
||
| return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); | ||
| return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); | ||
| } | ||
| } | ||
| return std::nullopt; | ||
| } | ||
|
|
||
|
|
@@ -712,6 +753,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { | |
| if (!maybeIntrinsic.has_value()) | ||
| return op.emitOpError("no intrinsic matching WMMA on the given chipset"); | ||
|
|
||
| if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0) | ||
| return op.emitOpError("subwordOffset not supported on gfx12+"); | ||
|
|
||
| OperationState loweredOp(loc, *maybeIntrinsic); | ||
| loweredOp.addTypes(rawOutType); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,68 @@ | ||
| // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s | ||
| func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) { | ||
| // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> | ||
| amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32> | ||
| // CHECK-LABEL: @wmma_to_rocdl | ||
| func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>, | ||
| %arg2 : vector<8xf32>, %arg3 : vector<4xf32>, | ||
| %arg4 : vector<8xbf16>, %arg5 : vector<4xbf16>, | ||
| %arg6 : vector<8xf8E4M3FN>, %arg7 : vector<4xf8E4M3FN>, | ||
| %arg8 : vector<8xf8E5M2>, %arg9 : vector<4xf8E5M2>, | ||
| %arg10 : vector<8xi8>, %arg11 : vector<4xi8>, | ||
| %arg12 : vector<8xi32>, %arg13 : vector<4xi32>, | ||
| %arg14 : vector<16xi4>, %arg15 : vector<8xi4>, %arg16 : vector<4xi4>) { | ||
| // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf32>) -> vector<8xf32> | ||
| amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32> | ||
| // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32> | ||
| amdgpu.wmma %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32> | ||
|
|
||
| // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xf32>) -> vector<8xf32> | ||
| amdgpu.wmma %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32> | ||
| // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32> | ||
| amdgpu.wmma %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32> | ||
|
|
||
| // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16> | ||
| amdgpu.wmma %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16> | ||
| // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16> | ||
| amdgpu.wmma %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16> | ||
|
|
||
| // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16> | ||
| // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16> | ||
| amdgpu.wmma %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16> | ||
| // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16> | ||
| amdgpu.wmma %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16> | ||
|
|
||
| // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> | ||
| amdgpu.wmma %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32> | ||
| // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> | ||
| amdgpu.wmma %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32> | ||
|
|
||
| // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> | ||
| amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32> | ||
| // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> | ||
| amdgpu.wmma %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32> | ||
|
|
||
| // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> | ||
| amdgpu.wmma %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32> | ||
| // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> | ||
| amdgpu.wmma %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32> | ||
|
|
||
| // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> | ||
| amdgpu.wmma %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32> | ||
| // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> | ||
| amdgpu.wmma %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32> | ||
|
|
||
| // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> | ||
| amdgpu.wmma %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32> | ||
| // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> | ||
| amdgpu.wmma %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32> | ||
|
|
||
| // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> | ||
| amdgpu.wmma %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32> | ||
| // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> | ||
| amdgpu.wmma %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32> | ||
|
|
||
| // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32> | ||
| amdgpu.wmma %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32> | ||
| // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> | ||
| amdgpu.wmma %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32> | ||
|
|
||
| // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> | ||
| amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32> | ||
| func.return | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the output can be f32 or i32 as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is in the context of the f16/bf16-outputting instructions