|
| 1 | +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_XE_ASM_FORMAT_H |
| 2 | +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_XE_ASM_FORMAT_H |
| 3 | + |
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
| 10 | + |
| 11 | +namespace mlir { |
| 12 | +class ConversionPatternRewriter; |
| 13 | +class Location; |
| 14 | + |
| 15 | +namespace triton { |
| 16 | +using llvm::StringRef; |
| 17 | + |
| 18 | +struct XeInstr; |
| 19 | +struct XeInstrCommon; |
| 20 | +struct XeInstrExecution; |
| 21 | + |
| 22 | +// XeBuilder helps to manage a Xe asm program on consists of one or multiple |
| 23 | +// instructions. |
| 24 | +// |
| 25 | +// A helper for building an ASM program, the objective of XeBuilder is to give |
| 26 | +// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear. |
| 27 | +// Currently, several factors are introduced to reduce the need for mixing |
| 28 | +// string and C++ if-else code. |
| 29 | +// |
| 30 | +// Usage: |
| 31 | +// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), |
| 32 | +// "b"(p)); |
| 33 | +// |
| 34 | +// XeBuilder builder; |
| 35 | +// auto& add = builder.create<>(); |
| 36 | +// add.predicate(pVal).o("lo").o("u32"); // add any suffix |
| 37 | +// // predicate here binds %0 to pVal, pVal is a mlir::Value |
| 38 | +// |
| 39 | +// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal |
| 40 | +// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal |
| 41 | +// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal |
| 42 | +// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate |
| 43 | +// |
| 44 | +// To get the asm code: |
| 45 | +// builder.dump() |
| 46 | +// |
| 47 | +// To get all the mlir::Value used in the Xe code, |
| 48 | +// |
| 49 | +// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal} |
| 50 | +// |
| 51 | +// To get the string containing all the constraints with "," separated, |
| 52 | +// builder.getConstraints() // get "=r,r,k" |
| 53 | +// |
| 54 | +// XeBuilder can build a Xe asm with multiple instructions, sample code: |
| 55 | +// |
| 56 | +// XeBuilder builder; |
| 57 | +// auto& mov = builder.create("mov"); |
| 58 | +// auto& cp = builder.create("cp"); |
| 59 | +// mov(...); |
| 60 | +// cp(...); |
| 61 | +// This will get a Xe code with two instructions. |
| 62 | +// |
| 63 | +// Similar to a C function, a declared XeInstr instance can be launched |
| 64 | +// multiple times with different operands, e.g. |
| 65 | +// |
| 66 | +// auto& mov = builder.create("mov"); |
| 67 | +// mov(... some operands ...); |
| 68 | +// mov(... some different operands ...); |
| 69 | +// |
| 70 | +// Finally, we will get a Xe code with two mov instructions. |
| 71 | +// |
| 72 | +// There are several derived instruction type for typical instructions, for |
| 73 | +// example, the PtxIOInstr for ld and st instructions. |
| 74 | +struct XeBuilder { |
| 75 | + struct Operand { |
| 76 | + std::string constraint; |
| 77 | + Value value; |
| 78 | + int idx{-1}; |
| 79 | + llvm::SmallVector<Operand *> list; |
| 80 | + std::function<std::string(int idx)> repr; |
| 81 | + |
| 82 | + // for list |
| 83 | + Operand() = default; |
| 84 | + Operand(const Operation &) = delete; |
| 85 | + Operand(Value value, StringRef constraint) |
| 86 | + : constraint(constraint), value(value) {} |
| 87 | + |
| 88 | + bool isList() const { return !value && constraint.empty(); } |
| 89 | + |
| 90 | + Operand *listAppend(Operand *arg) { |
| 91 | + list.push_back(arg); |
| 92 | + return this; |
| 93 | + } |
| 94 | + |
| 95 | + Operand *listGet(size_t nth) const { |
| 96 | + assert(nth < list.size() && |
| 97 | + "get asm operands of Xe assembler out of range."); |
| 98 | + return list[nth]; |
| 99 | + } |
| 100 | + |
| 101 | + std::string dump() const; |
| 102 | + }; |
| 103 | + |
| 104 | + template <typename INSTR = XeInstr, typename... Args> |
| 105 | + INSTR *create(Args &&...args) { |
| 106 | + instrs.emplace_back(std::make_unique<INSTR>(this, args...)); |
| 107 | + return static_cast<INSTR *>(instrs.back().get()); |
| 108 | + } |
| 109 | + |
| 110 | + // Create a list of operands. |
| 111 | + Operand *newListOperand() { return newOperand(); } |
| 112 | + |
| 113 | + Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) { |
| 114 | + auto *list = newOperand(); |
| 115 | + for (auto &item : items) { |
| 116 | + list->listAppend(newOperand(item.first, item.second)); |
| 117 | + } |
| 118 | + return list; |
| 119 | + } |
| 120 | + |
| 121 | + Operand *newListOperand(unsigned count, mlir::Value val, |
| 122 | + const std::string &constraint) { |
| 123 | + auto *list = newOperand(); |
| 124 | + for (unsigned i = 0; i < count; ++i) { |
| 125 | + list->listAppend(newOperand(val, constraint)); |
| 126 | + } |
| 127 | + return list; |
| 128 | + } |
| 129 | + |
| 130 | + Operand *newListOperand(unsigned count, const std::string &constraint) { |
| 131 | + auto *list = newOperand(); |
| 132 | + for (unsigned i = 0; i < count; ++i) { |
| 133 | + list->listAppend(newOperand(constraint)); |
| 134 | + } |
| 135 | + return list; |
| 136 | + } |
| 137 | + |
| 138 | + // Create a new operand. It will not add to operand list. |
| 139 | + // @value: the MLIR value bind to this operand. |
| 140 | + // @constraint: ASM operand constraint, .e.g. "=r" |
| 141 | + // @formatter: extra format to represent this operand in ASM code, default is |
| 142 | + // "%{0}".format(operand.idx). |
| 143 | + Operand *newOperand(mlir::Value value, StringRef constraint, |
| 144 | + std::function<std::string(int idx)> formatter = nullptr); |
| 145 | + |
| 146 | + // Create a new operand which is written to, that is, the constraint starts |
| 147 | + // with "=", e.g. "=r". |
| 148 | + // If the operand will be used in predicated execution, |
| 149 | + // users may want to initialize it before use. |
| 150 | + // Otherwise if the register is only used in the true branch or the false |
| 151 | + // branch but not both, the register is undefined and ptxas can perform |
| 152 | + // aggressive optimizations that may lead to incorrect results. |
| 153 | + Operand *newOperand(StringRef constraint, bool init = false); |
| 154 | + |
| 155 | + // Create a new operand that is tied to a previous operand. In this case the |
| 156 | + // asm would be permitted to write to an input register. Instead of providing |
| 157 | + // constraint code for this operand, the constraint code of the tied operand |
| 158 | + // is used. |
| 159 | + Operand *newOperand(unsigned operandIndex); |
| 160 | + |
| 161 | + // Create a constant integer operand. |
| 162 | + Operand *newConstantOperand(int64_t v); |
| 163 | + // Create a constant operand with explicit code specified. |
| 164 | + Operand *newConstantOperand(const std::string &v); |
| 165 | + |
| 166 | + Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0); |
| 167 | + |
| 168 | + llvm::SmallVector<Operand *, 4> getAllArgs() const; |
| 169 | + |
| 170 | + llvm::SmallVector<Value, 4> getAllMLIRArgs() const; |
| 171 | + |
| 172 | + std::string getConstraints() const; |
| 173 | + |
| 174 | + std::string dump() const; |
| 175 | + |
| 176 | + mlir::Value launch(OpBuilder &rewriter, Location loc, Type resTy, |
| 177 | + bool hasSideEffect = true, bool isAlignStack = false, |
| 178 | + ArrayRef<Attribute> attrs = {}) const; |
| 179 | + |
| 180 | +private: |
| 181 | + Operand *newOperand() { |
| 182 | + argArchive.emplace_back(std::make_unique<Operand>()); |
| 183 | + return argArchive.back().get(); |
| 184 | + } |
| 185 | + |
| 186 | + void initOperand(Operand *opr); |
| 187 | + |
| 188 | + // Make the operands in argArchive follow the provided \param order. |
| 189 | + void reorderArgArchive(ArrayRef<Operand *> order) { |
| 190 | + assert(order.size() == argArchive.size()); |
| 191 | + // The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but |
| 192 | + // it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are |
| 193 | + // determined by Xe code snippet passed from external. |
| 194 | + sort(argArchive.begin(), argArchive.end(), |
| 195 | + [&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) { |
| 196 | + auto ida = std::find(order.begin(), order.end(), a.get()); |
| 197 | + auto idb = std::find(order.begin(), order.end(), b.get()); |
| 198 | + assert(ida != order.end()); |
| 199 | + assert(idb != order.end()); |
| 200 | + return ida < idb; |
| 201 | + }); |
| 202 | + } |
| 203 | + |
| 204 | + friend struct XeInstr; |
| 205 | + friend struct XeInstrCommon; |
| 206 | + |
| 207 | +protected: |
| 208 | + llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive; |
| 209 | + llvm::SmallVector<std::unique_ptr<XeInstrCommon>, 2> instrs; |
| 210 | + llvm::SmallVector<std::unique_ptr<XeInstrExecution>, 4> executions; |
| 211 | + int oprCounter{}; |
| 212 | +}; |
| 213 | + |
| 214 | +// Xe instruction common interface. |
| 215 | +// Put the generic logic for all the instructions here. |
| 216 | +struct XeInstrCommon { |
| 217 | + explicit XeInstrCommon(XeBuilder *builder) : builder(builder) {} |
| 218 | + |
| 219 | + using Operand = XeBuilder::Operand; |
| 220 | + |
| 221 | + template <typename... ARGS, |
| 222 | + std::enable_if_t<std::conjunction_v<std::is_same<ARGS, Operand>...>, |
| 223 | + int> = 0> |
| 224 | + XeInstrExecution &operator()(ARGS *...args) { |
| 225 | + return call({args...}); |
| 226 | + } |
| 227 | + |
| 228 | + // Set operands of this instruction. |
| 229 | + XeInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs, |
| 230 | + bool onlyAttachMLIRArgs = false); |
| 231 | + |
| 232 | +protected: |
| 233 | + // "Call" the instruction with operands. |
| 234 | + // \param oprs The operands of this instruction. |
| 235 | + // \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments |
| 236 | + // to the inline Asm without generating the operand ids(such as $0, $1) in |
| 237 | + // Xe code. |
| 238 | + XeInstrExecution &call(llvm::ArrayRef<Operand *> oprs, |
| 239 | + bool onlyAttachMLIRArgs = false); |
| 240 | + |
| 241 | + XeBuilder *builder{}; |
| 242 | + llvm::SmallVector<std::string, 4> instrParts; |
| 243 | + |
| 244 | + friend struct XeInstrExecution; |
| 245 | +}; |
| 246 | + |
| 247 | +template <class ConcreteT> struct XeInstrBase : public XeInstrCommon { |
| 248 | + using Operand = XeBuilder::Operand; |
| 249 | + |
| 250 | + explicit XeInstrBase(XeBuilder *builder, const std::string &name) |
| 251 | + : XeInstrCommon(builder) { |
| 252 | + o(name); |
| 253 | + } |
| 254 | + |
| 255 | + // Append a suffix to the instruction. |
| 256 | + // e.g. XeInstr("add").o("s32") get a add.s32. |
| 257 | + // A predicate is used to tell whether to apply the suffix, so that no if-else |
| 258 | + // code needed. e.g. `XeInstr("add").o("s32", isS32).o("u32", !isS32);` will |
| 259 | + // get a `add.s32` if isS32 is true. |
| 260 | + ConcreteT &o(const std::string &suffix, bool predicate = true) { |
| 261 | + if (predicate) |
| 262 | + instrParts.push_back(suffix); |
| 263 | + return *static_cast<ConcreteT *>(this); |
| 264 | + } |
| 265 | +}; |
| 266 | + |
| 267 | +struct XeInstr : public XeInstrBase<XeInstr> { |
| 268 | + using XeInstrBase<XeInstr>::XeInstrBase; |
| 269 | + |
| 270 | + // Append a ".global" to the instruction. |
| 271 | + XeInstr &global(); |
| 272 | + |
| 273 | + // Append a ".shared" to the instruction. |
| 274 | + XeInstr &shared(); |
| 275 | + |
| 276 | + // Append a ".v[0-9]+" to the instruction |
| 277 | + XeInstr &v(int vecWidth, bool predicate = true); |
| 278 | + |
| 279 | + // Append a".b[0-9]+" to the instruction |
| 280 | + XeInstr &b(int width); |
| 281 | +}; |
| 282 | + |
| 283 | +// Record the operands and context for "launching" a XeInstr. |
| 284 | +struct XeInstrExecution { |
| 285 | + using Operand = XeBuilder::Operand; |
| 286 | + |
| 287 | + llvm::SmallVector<Operand *> argsInOrder; |
| 288 | + |
| 289 | + XeInstrExecution() = default; |
| 290 | + explicit XeInstrExecution(XeInstrCommon *instr, |
| 291 | + llvm::ArrayRef<Operand *> oprs, |
| 292 | + bool onlyAttachMLIRArgs) |
| 293 | + : argsInOrder(oprs.begin(), oprs.end()), instr(instr), |
| 294 | + onlyAttachMLIRArgs(onlyAttachMLIRArgs) {} |
| 295 | + |
| 296 | + // Prefix a predicate to the instruction. |
| 297 | + XeInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") { |
| 298 | + pred = instr->builder->newOperand(value, constraint); |
| 299 | + return *this; |
| 300 | + } |
| 301 | + |
| 302 | + // Prefix a !predicate to the instruction. |
| 303 | + XeInstrExecution &predicateNot(mlir::Value value, StringRef constraint) { |
| 304 | + pred = instr->builder->newOperand(value, constraint); |
| 305 | + pred->repr = [](int idx) { return "@!$" + std::to_string(idx); }; |
| 306 | + return *this; |
| 307 | + } |
| 308 | + |
| 309 | + std::string dump() const; |
| 310 | + |
| 311 | + SmallVector<Operand *> getArgList() const; |
| 312 | + |
| 313 | + XeInstrCommon *instr{}; |
| 314 | + Operand *pred{}; |
| 315 | + bool onlyAttachMLIRArgs{}; |
| 316 | +}; |
| 317 | + |
| 318 | +} // namespace triton |
| 319 | +} // namespace mlir |
| 320 | + |
| 321 | +#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_XE_ASM_FORMAT_H |
0 commit comments