|
14 | 14 | // |
15 | 15 | //===----------------------------------------------------------------------===// |
16 | 16 |
|
| 17 | +// TODO: uses or report_fatal_error (which is also deprecated) / |
| 18 | +// ReportFatalUsageError in this file should be refactored, as per LLVM |
| 19 | +// best practices, to rely on the Diagnostic infrastructure. |
| 20 | + |
17 | 21 | #include "SPIRVModuleAnalysis.h" |
18 | 22 | #include "MCTargetDesc/SPIRVBaseInfo.h" |
19 | 23 | #include "MCTargetDesc/SPIRVMCTargetDesc.h" |
@@ -1071,13 +1075,50 @@ static bool isBFloat16Type(const SPIRVType *TypeDef) { |
1071 | 1075 | #define ATOM_FLT_REQ_EXT_MSG(ExtName) \ |
1072 | 1076 | "The atomic float instruction requires the following SPIR-V " \ |
1073 | 1077 | "extension: SPV_EXT_shader_atomic_float" ExtName |
| 1078 | +static void AddAtomicVectorFloatRequirements(const MachineInstr &MI, |
| 1079 | + SPIRV::RequirementHandler &Reqs, |
| 1080 | + const SPIRVSubtarget &ST) { |
| 1081 | + SPIRVType *VecTypeDef = |
| 1082 | + MI.getMF()->getRegInfo().getVRegDef(MI.getOperand(1).getReg()); |
| 1083 | + |
| 1084 | + const unsigned Rank = VecTypeDef->getOperand(2).getImm(); |
| 1085 | + if (Rank != 2 && Rank != 4) |
| 1086 | + reportFatalUsageError("Result type of an atomic vector float instruction " |
| 1087 | + "must be a 2-component or 4 component vector"); |
| 1088 | + |
| 1089 | + SPIRVType *EltTypeDef = |
| 1090 | + MI.getMF()->getRegInfo().getVRegDef(VecTypeDef->getOperand(1).getReg()); |
| 1091 | + |
| 1092 | + if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat || |
| 1093 | + EltTypeDef->getOperand(1).getImm() != 16) |
| 1094 | + reportFatalUsageError( |
| 1095 | + "The element type for the result type of an atomic vector float " |
| 1096 | + "instruction must be a 16-bit floating-point scalar"); |
| 1097 | + |
| 1098 | + if (isBFloat16Type(EltTypeDef)) |
| 1099 | + reportFatalUsageError( |
| 1100 | + "The element type for the result type of an atomic vector float " |
| 1101 | + "instruction cannot be a bfloat16 scalar"); |
| 1102 | + if (!ST.canUseExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector)) |
| 1103 | + reportFatalUsageError( |
| 1104 | + "The atomic float16 vector instruction requires the following SPIR-V " |
| 1105 | + "extension: SPV_NV_shader_atomic_fp16_vector"); |
| 1106 | + |
| 1107 | + Reqs.addExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector); |
| 1108 | + Reqs.addCapability(SPIRV::Capability::AtomicFloat16VectorNV); |
| 1109 | +} |
| 1110 | + |
1074 | 1111 | static void AddAtomicFloatRequirements(const MachineInstr &MI, |
1075 | 1112 | SPIRV::RequirementHandler &Reqs, |
1076 | 1113 | const SPIRVSubtarget &ST) { |
1077 | 1114 | assert(MI.getOperand(1).isReg() && |
1078 | 1115 | "Expect register operand in atomic float instruction"); |
1079 | 1116 | Register TypeReg = MI.getOperand(1).getReg(); |
1080 | 1117 | SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg); |
| 1118 | + |
| 1119 | + if (TypeDef->getOpcode() == SPIRV::OpTypeVector) |
| 1120 | + return AddAtomicVectorFloatRequirements(MI, Reqs, ST); |
| 1121 | + |
1081 | 1122 | if (TypeDef->getOpcode() != SPIRV::OpTypeFloat) |
1082 | 1123 | report_fatal_error("Result type of an atomic float instruction must be a " |
1083 | 1124 | "floating-point type scalar"); |
|
0 commit comments