Skip to content

Commit 88d8343

Browse files
authored
Add ASM builder for Xe GPU. (#3822)
The XeAsmFormat.h and XeAsmFormat.cpp are almost same to the PTX ASM helper for NV backend to keep the feature for SIMT ISA.
1 parent f69ba3e commit 88d8343

File tree

8 files changed

+653
-1
lines changed

8 files changed

+653
-1
lines changed
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_XE_ASM_FORMAT_H
2+
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_XE_ASM_FORMAT_H
3+
4+
#include "mlir/IR/Value.h"
5+
#include "triton/Dialect/Triton/IR/Dialect.h"
6+
#include "llvm/ADT/SmallVector.h"
7+
#include "llvm/ADT/StringRef.h"
8+
#include <memory>
9+
#include <string>
10+
11+
namespace mlir {
12+
class ConversionPatternRewriter;
13+
class Location;
14+
15+
namespace triton {
16+
using llvm::StringRef;
17+
18+
struct XeInstr;
19+
struct XeInstrCommon;
20+
struct XeInstrExecution;
21+
22+
// XeBuilder helps to manage a Xe asm program on consists of one or multiple
23+
// instructions.
24+
//
25+
// A helper for building an ASM program, the objective of XeBuilder is to give
26+
// a thin encapsulation and make the ASM code for MLIR LLVM Dialect more clear.
27+
// Currently, several factors are introduced to reduce the need for mixing
28+
// string and C++ if-else code.
29+
//
30+
// Usage:
31+
// To build: @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k),
32+
// "b"(p));
33+
//
34+
// XeBuilder builder;
35+
// auto& add = builder.create<>();
36+
// add.predicate(pVal).o("lo").o("u32"); // add any suffix
37+
// // predicate here binds %0 to pVal, pVal is a mlir::Value
38+
//
39+
// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal
40+
// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal
41+
// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal
42+
// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate
43+
//
44+
// To get the asm code:
45+
// builder.dump()
46+
//
47+
// To get all the mlir::Value used in the Xe code,
48+
//
49+
// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal}
50+
//
51+
// To get the string containing all the constraints with "," separated,
52+
// builder.getConstraints() // get "=r,r,k"
53+
//
54+
// XeBuilder can build a Xe asm with multiple instructions, sample code:
55+
//
56+
// XeBuilder builder;
57+
// auto& mov = builder.create("mov");
58+
// auto& cp = builder.create("cp");
59+
// mov(...);
60+
// cp(...);
61+
// This will get a Xe code with two instructions.
62+
//
63+
// Similar to a C function, a declared XeInstr instance can be launched
64+
// multiple times with different operands, e.g.
65+
//
66+
// auto& mov = builder.create("mov");
67+
// mov(... some operands ...);
68+
// mov(... some different operands ...);
69+
//
70+
// Finally, we will get a Xe code with two mov instructions.
71+
//
72+
// There are several derived instruction type for typical instructions, for
73+
// example, the PtxIOInstr for ld and st instructions.
74+
struct XeBuilder {
75+
struct Operand {
76+
std::string constraint;
77+
Value value;
78+
int idx{-1};
79+
llvm::SmallVector<Operand *> list;
80+
std::function<std::string(int idx)> repr;
81+
82+
// for list
83+
Operand() = default;
84+
Operand(const Operation &) = delete;
85+
Operand(Value value, StringRef constraint)
86+
: constraint(constraint), value(value) {}
87+
88+
bool isList() const { return !value && constraint.empty(); }
89+
90+
Operand *listAppend(Operand *arg) {
91+
list.push_back(arg);
92+
return this;
93+
}
94+
95+
Operand *listGet(size_t nth) const {
96+
assert(nth < list.size() &&
97+
"get asm operands of Xe assembler out of range.");
98+
return list[nth];
99+
}
100+
101+
std::string dump() const;
102+
};
103+
104+
template <typename INSTR = XeInstr, typename... Args>
105+
INSTR *create(Args &&...args) {
106+
instrs.emplace_back(std::make_unique<INSTR>(this, args...));
107+
return static_cast<INSTR *>(instrs.back().get());
108+
}
109+
110+
// Create a list of operands.
111+
Operand *newListOperand() { return newOperand(); }
112+
113+
Operand *newListOperand(ArrayRef<std::pair<mlir::Value, std::string>> items) {
114+
auto *list = newOperand();
115+
for (auto &item : items) {
116+
list->listAppend(newOperand(item.first, item.second));
117+
}
118+
return list;
119+
}
120+
121+
Operand *newListOperand(unsigned count, mlir::Value val,
122+
const std::string &constraint) {
123+
auto *list = newOperand();
124+
for (unsigned i = 0; i < count; ++i) {
125+
list->listAppend(newOperand(val, constraint));
126+
}
127+
return list;
128+
}
129+
130+
Operand *newListOperand(unsigned count, const std::string &constraint) {
131+
auto *list = newOperand();
132+
for (unsigned i = 0; i < count; ++i) {
133+
list->listAppend(newOperand(constraint));
134+
}
135+
return list;
136+
}
137+
138+
// Create a new operand. It will not add to operand list.
139+
// @value: the MLIR value bind to this operand.
140+
// @constraint: ASM operand constraint, .e.g. "=r"
141+
// @formatter: extra format to represent this operand in ASM code, default is
142+
// "%{0}".format(operand.idx).
143+
Operand *newOperand(mlir::Value value, StringRef constraint,
144+
std::function<std::string(int idx)> formatter = nullptr);
145+
146+
// Create a new operand which is written to, that is, the constraint starts
147+
// with "=", e.g. "=r".
148+
// If the operand will be used in predicated execution,
149+
// users may want to initialize it before use.
150+
// Otherwise if the register is only used in the true branch or the false
151+
// branch but not both, the register is undefined and ptxas can perform
152+
// aggressive optimizations that may lead to incorrect results.
153+
Operand *newOperand(StringRef constraint, bool init = false);
154+
155+
// Create a new operand that is tied to a previous operand. In this case the
156+
// asm would be permitted to write to an input register. Instead of providing
157+
// constraint code for this operand, the constraint code of the tied operand
158+
// is used.
159+
Operand *newOperand(unsigned operandIndex);
160+
161+
// Create a constant integer operand.
162+
Operand *newConstantOperand(int64_t v);
163+
// Create a constant operand with explicit code specified.
164+
Operand *newConstantOperand(const std::string &v);
165+
166+
Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
167+
168+
llvm::SmallVector<Operand *, 4> getAllArgs() const;
169+
170+
llvm::SmallVector<Value, 4> getAllMLIRArgs() const;
171+
172+
std::string getConstraints() const;
173+
174+
std::string dump() const;
175+
176+
mlir::Value launch(OpBuilder &rewriter, Location loc, Type resTy,
177+
bool hasSideEffect = true, bool isAlignStack = false,
178+
ArrayRef<Attribute> attrs = {}) const;
179+
180+
private:
181+
Operand *newOperand() {
182+
argArchive.emplace_back(std::make_unique<Operand>());
183+
return argArchive.back().get();
184+
}
185+
186+
void initOperand(Operand *opr);
187+
188+
// Make the operands in argArchive follow the provided \param order.
189+
void reorderArgArchive(ArrayRef<Operand *> order) {
190+
assert(order.size() == argArchive.size());
191+
// The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but
192+
// it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are
193+
// determined by Xe code snippet passed from external.
194+
sort(argArchive.begin(), argArchive.end(),
195+
[&](std::unique_ptr<Operand> &a, std::unique_ptr<Operand> &b) {
196+
auto ida = std::find(order.begin(), order.end(), a.get());
197+
auto idb = std::find(order.begin(), order.end(), b.get());
198+
assert(ida != order.end());
199+
assert(idb != order.end());
200+
return ida < idb;
201+
});
202+
}
203+
204+
friend struct XeInstr;
205+
friend struct XeInstrCommon;
206+
207+
protected:
208+
llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
209+
llvm::SmallVector<std::unique_ptr<XeInstrCommon>, 2> instrs;
210+
llvm::SmallVector<std::unique_ptr<XeInstrExecution>, 4> executions;
211+
int oprCounter{};
212+
};
213+
214+
// Xe instruction common interface.
215+
// Put the generic logic for all the instructions here.
216+
struct XeInstrCommon {
217+
explicit XeInstrCommon(XeBuilder *builder) : builder(builder) {}
218+
219+
using Operand = XeBuilder::Operand;
220+
221+
template <typename... ARGS,
222+
std::enable_if_t<std::conjunction_v<std::is_same<ARGS, Operand>...>,
223+
int> = 0>
224+
XeInstrExecution &operator()(ARGS *...args) {
225+
return call({args...});
226+
}
227+
228+
// Set operands of this instruction.
229+
XeInstrExecution &operator()(llvm::ArrayRef<Operand *> oprs,
230+
bool onlyAttachMLIRArgs = false);
231+
232+
protected:
233+
// "Call" the instruction with operands.
234+
// \param oprs The operands of this instruction.
235+
// \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments
236+
// to the inline Asm without generating the operand ids(such as $0, $1) in
237+
// Xe code.
238+
XeInstrExecution &call(llvm::ArrayRef<Operand *> oprs,
239+
bool onlyAttachMLIRArgs = false);
240+
241+
XeBuilder *builder{};
242+
llvm::SmallVector<std::string, 4> instrParts;
243+
244+
friend struct XeInstrExecution;
245+
};
246+
247+
template <class ConcreteT> struct XeInstrBase : public XeInstrCommon {
248+
using Operand = XeBuilder::Operand;
249+
250+
explicit XeInstrBase(XeBuilder *builder, const std::string &name)
251+
: XeInstrCommon(builder) {
252+
o(name);
253+
}
254+
255+
// Append a suffix to the instruction.
256+
// e.g. XeInstr("add").o("s32") get a add.s32.
257+
// A predicate is used to tell whether to apply the suffix, so that no if-else
258+
// code needed. e.g. `XeInstr("add").o("s32", isS32).o("u32", !isS32);` will
259+
// get a `add.s32` if isS32 is true.
260+
ConcreteT &o(const std::string &suffix, bool predicate = true) {
261+
if (predicate)
262+
instrParts.push_back(suffix);
263+
return *static_cast<ConcreteT *>(this);
264+
}
265+
};
266+
267+
struct XeInstr : public XeInstrBase<XeInstr> {
268+
using XeInstrBase<XeInstr>::XeInstrBase;
269+
270+
// Append a ".global" to the instruction.
271+
XeInstr &global();
272+
273+
// Append a ".shared" to the instruction.
274+
XeInstr &shared();
275+
276+
// Append a ".v[0-9]+" to the instruction
277+
XeInstr &v(int vecWidth, bool predicate = true);
278+
279+
// Append a".b[0-9]+" to the instruction
280+
XeInstr &b(int width);
281+
};
282+
283+
// Record the operands and context for "launching" a XeInstr.
284+
struct XeInstrExecution {
285+
using Operand = XeBuilder::Operand;
286+
287+
llvm::SmallVector<Operand *> argsInOrder;
288+
289+
XeInstrExecution() = default;
290+
explicit XeInstrExecution(XeInstrCommon *instr,
291+
llvm::ArrayRef<Operand *> oprs,
292+
bool onlyAttachMLIRArgs)
293+
: argsInOrder(oprs.begin(), oprs.end()), instr(instr),
294+
onlyAttachMLIRArgs(onlyAttachMLIRArgs) {}
295+
296+
// Prefix a predicate to the instruction.
297+
XeInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
298+
pred = instr->builder->newOperand(value, constraint);
299+
return *this;
300+
}
301+
302+
// Prefix a !predicate to the instruction.
303+
XeInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
304+
pred = instr->builder->newOperand(value, constraint);
305+
pred->repr = [](int idx) { return "@!$" + std::to_string(idx); };
306+
return *this;
307+
}
308+
309+
std::string dump() const;
310+
311+
SmallVector<Operand *> getArgList() const;
312+
313+
XeInstrCommon *instr{};
314+
Operand *pred{};
315+
bool onlyAttachMLIRArgs{};
316+
};
317+
318+
} // namespace triton
319+
} // namespace mlir
320+
321+
#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_XE_ASM_FORMAT_H

third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,13 @@ class SmallVectorBuffer : public std::streambuf {
107107

108108
static SPIRV::TranslatorOpts getSPIRVOopts() {
109109
SPIRV::TranslatorOpts SPIRVOpts;
110-
static constexpr std::array<SPIRV::ExtensionID, 13> AllowedExtensions{
110+
static constexpr std::array<SPIRV::ExtensionID, 14> AllowedExtensions{
111111
SPIRV::ExtensionID::SPV_EXT_shader_atomic_float_add,
112112
SPIRV::ExtensionID::SPV_INTEL_arbitrary_precision_integers,
113113
SPIRV::ExtensionID::SPV_INTEL_arithmetic_fence,
114114
SPIRV::ExtensionID::SPV_INTEL_bfloat16_conversion,
115115
SPIRV::ExtensionID::SPV_INTEL_cache_controls,
116+
SPIRV::ExtensionID::SPV_INTEL_inline_assembly,
116117
SPIRV::ExtensionID::SPV_INTEL_kernel_attributes,
117118
SPIRV::ExtensionID::SPV_INTEL_memory_access_aliasing,
118119
SPIRV::ExtensionID::SPV_INTEL_subgroups,

third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_triton_library(TritonIntelGPUToLLVM
2121
TritonOpsToLLVM.cpp
2222
TypeConverter.cpp
2323
Utility.cpp
24+
XeAsmFormat.cpp
2425

2526
DEPENDS
2627
TritonIntelGPUConversionPassIncGen

0 commit comments

Comments
 (0)