-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[RISCV] Support Inline ASM for the bf16 type. #80118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-risc-v Author: Chuan-Yue Yuan (circYuan) ChangesThis patch makes the RISCV-V asm constraint Full diff: https://github.com/llvm/llvm-project/pull/80118.diff 3 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5ce1013f30fd1..7342e7bcba1f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -9105,7 +9105,7 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL,
// Get the actual register value type. This is important, because the user
// may have asked for (e.g.) the AX register in i32 type. We need to
// remember that AX is actually i16 to get the right extension.
- const MVT RegVT = *TRI.legalclasstypes_begin(*RC);
+ MVT RegVT = *TRI.legalclasstypes_begin(*RC);
if (OpInfo.ConstraintVT != MVT::Other && RegVT != MVT::Untyped) {
// If this is an FP operand in an integer register (or visa versa), or more
@@ -9139,6 +9139,17 @@ getRegistersForValue(SelectionDAG &DAG, const SDLoc &DL,
DAG.getNode(ISD::BITCAST, DL, VT, OpInfo.CallOperand);
OpInfo.ConstraintVT = VT;
}
+ // If the RegisterClass contains more than one types like RISCV
+ // FPR16RegClass which has [f16, bf16], We should check if the
+ // OpInfo.ConstraintVT can directly be assigned to the RegVT.
+ } else if ((OpInfo.Type == InlineAsm::isOutput ||
+ OpInfo.Type == InlineAsm::isInput) &&
+ TRI.isTypeLegalForClass(*RC, OpInfo.ConstraintVT)) {
+ if (RegVT != OpInfo.ConstraintVT &&
+ RegVT.getSizeInBits() == OpInfo.ConstraintVT.getSizeInBits() &&
+ RegVT.isFloatingPoint() && OpInfo.ConstraintVT.isFloatingPoint()) {
+ RegVT = OpInfo.ConstraintVT;
+ }
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b8994e7b7bdb2..7a6e41ab7fee3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19225,6 +19225,8 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
return std::make_pair(0U, &RISCV::GPRPairRegClass);
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
case 'f':
+ if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16)
+ return std::make_pair(0U, &RISCV::FPR16RegClass);
if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
return std::make_pair(0U, &RISCV::FPR16RegClass);
if (Subtarget.hasStdExtF() && VT == MVT::f32)
@@ -19343,6 +19345,11 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
unsigned HReg = RISCV::F0_H + RegNo;
return std::make_pair(HReg, &RISCV::FPR16RegClass);
}
+ if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16) {
+ unsigned RegNo = FReg - RISCV::F0_F;
+ unsigned HReg = RISCV::F0_H + RegNo;
+ return std::make_pair(HReg, &RISCV::FPR16RegClass);
+ }
}
}
diff --git a/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
new file mode 100644
index 0000000000000..a496e2fea173e
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/inline-asm-f-constraint-bf16.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr=+f,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN: | FileCheck -check-prefix=RV32F %s
+; RUN: llc -mtriple=riscv64 -mattr=+f,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN: | FileCheck -check-prefix=RV64F %s
+; RUN: llc -mtriple=riscv32 -mattr=+d,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \
+; RUN: | FileCheck -check-prefix=RV32F %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+experimental-zfbfmin -target-abi=lp64 -verify-machineinstrs < %s \
+; RUN: | FileCheck -check-prefix=RV64F %s
+
+@gf = external global float
+
+define float @constraint_f_float(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_float:
+; RV32F: # %bb.0:
+; RV32F-NEXT: fmv.h.x fa5, a0
+; RV32F-NEXT: #APP
+; RV32F-NEXT: fcvt.s.bf16 fa5, fa5
+; RV32F-NEXT: #NO_APP
+; RV32F-NEXT: fmv.x.w a0, fa5
+; RV32F-NEXT: ret
+;
+; RV64F-LABEL: constraint_f_float:
+; RV64F: # %bb.0:
+; RV64F-NEXT: fmv.h.x fa5, a0
+; RV64F-NEXT: #APP
+; RV64F-NEXT: fcvt.s.bf16 fa5, fa5
+; RV64F-NEXT: #NO_APP
+; RV64F-NEXT: fmv.x.w a0, fa5
+; RV64F-NEXT: ret
+ %1 = load float, ptr @gf
+ %2 = tail call float asm "fcvt.s.bf16 $0, $1", "=f,f"(bfloat %a)
+ ret float %2
+}
+
+define float @constraint_f_float_abi_name(bfloat %a) nounwind {
+; RV32F-LABEL: constraint_f_float_abi_name:
+; RV32F: # %bb.0:
+; RV32F-NEXT: fmv.h.x fa0, a0
+; RV32F-NEXT: #APP
+; RV32F-NEXT: fcvt.s.bf16 ft0, fa0
+; RV32F-NEXT: #NO_APP
+; RV32F-NEXT: fmv.x.w a0, ft0
+; RV32F-NEXT: ret
+;
+; RV64F-LABEL: constraint_f_float_abi_name:
+; RV64F: # %bb.0:
+; RV64F-NEXT: fmv.h.x fa0, a0
+; RV64F-NEXT: #APP
+; RV64F-NEXT: fcvt.s.bf16 ft0, fa0
+; RV64F-NEXT: #NO_APP
+; RV64F-NEXT: fmv.x.w a0, ft0
+; RV64F-NEXT: ret
+ %1 = load float, ptr @gf
+ %2 = tail call float asm "fcvt.s.bf16 $0, $1", "={ft0},{fa0}"(bfloat %a)
+ ret float %2
+}
+
+define bfloat @constraint_gpr(bfloat %x) {
+; RV32F-LABEL: constraint_gpr:
+; RV32F: # %bb.0:
+; RV32F-NEXT: .cfi_def_cfa_offset 0
+; RV32F-NEXT: #APP
+; RV32F-NEXT: mv a0, a0
+; RV32F-NEXT: #NO_APP
+; RV32F-NEXT: ret
+;
+; RV64F-LABEL: constraint_gpr:
+; RV64F: # %bb.0:
+; RV64F-NEXT: .cfi_def_cfa_offset 0
+; RV64F-NEXT: #APP
+; RV64F-NEXT: mv a0, a0
+; RV64F-NEXT: #NO_APP
+; RV64F-NEXT: ret
+ %1 = tail call bfloat asm sideeffect alignstack "mv $0, $1", "={x10},{x10}"(bfloat %x)
+ ret bfloat %1
+}
|
| unsigned HReg = RISCV::F0_H + RegNo; | ||
| return std::make_pair(HReg, &RISCV::FPR16RegClass); | ||
| } | ||
| if (Subtarget.hasStdExtZfbfmin() && VT == MVT::bf16) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Look like the same with the one for f16. Could we merge.
|
You don't need |
| // may have asked for (e.g.) the AX register in i32 type. We need to | ||
| // remember that AX is actually i16 to get the right extension. | ||
| const MVT RegVT = *TRI.legalclasstypes_begin(*RC); | ||
| MVT RegVT = *TRI.legalclasstypes_begin(*RC); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The generic code change feels wrong. It feels really wrong that this just takes the first element of RC. Mutating this later just makes it harder to follow. I think this block of code should either not depend on RegVT at all, or this needs to be a loop over all legal class types
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you mean that we should keep the RegVT constant, but choose not just the first legal value type for creating the SelectionDAG? Such fix is like using a lambda function with for loop for choosing the reasonable value type, in this case, not fp16 but bf16.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something like that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding a new hook that implements TRI.legalclasstypes_begin is not what I meant. That doesn't really address the issue, and just forces more complexity into every target. I meant all of this code here can be factored into a function and repeated for every type in legalclasstypes until one works
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
| ; RV64F-NEXT: #NO_APP | ||
| ; RV64F-NEXT: fmv.x.w a0, fa5 | ||
| ; RV64F-NEXT: ret | ||
| %1 = load float, float* @gf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to use opaque pointers
| @@ -0,0 +1,77 @@ | |||
| ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py | |||
| ; RUN: llc -mtriple=riscv32 -mattr=+f,+experimental-zfbfmin -target-abi=ilp32 -verify-machineinstrs < %s \ | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't need all the -verify-machineinstrs
| ; RV64F-NEXT: ret | ||
| %1 = tail call bfloat asm sideeffect alignstack "mv $0, $1", "={x10},{x10}"(bfloat %x) | ||
| ret bfloat %1 | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add tests using virtual registers too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bitcast handling shouldn't be strictly necessary for the asm support and should be done in a separate PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I don't combine the bitcast this time, the new test I added would be fail since RISCV doesn't have the instruction for selecting the bf16 bitcast fp16. In the IR level, it even doesn't allow the bitcast between fp16 <--> bf16.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then this is entirely incorrect. You cannot rely on this DAG combine to fix up a selection failure
In the IR level, it even doesn't allow the bitcast between fp16 <--> bf16.
This is always allowed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, sorry for that I don't realize about this policy originally. I will seek the correct way to handle it, thanks for the comment!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You probably need to add bf16 as a legal type if you want to support registers that directly contain it (and then directly handle bitcasts in selection). That's approximately what legal means anyway. You can only get so far by hacking it into making CopyFromReg legal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, bitcast between fp16 and bf16 is allowed, sorry for the misunderstanding.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should probably be rewritten as a generic combine for any pairs where isTypeLegalForClass is true in both registers
This patch makes the RISCV-V asm constraint
frecognize the bfloat type.