
Commit 97d2441

Merge branch 'main' into gregory/windows-support

Signed-off-by: Gregory Shimansky <[email protected]>
2 parents 14bc5c2 + 64b232e


42 files changed: +352 −1105 lines

.github/workflows/llvm-build.yml

Lines changed: 2 additions & 2 deletions
@@ -157,8 +157,8 @@ jobs:
     cp -r /usr/aarch64-linux-gnu/lib ./arm-sysroot
     cp -r /usr/aarch64-linux-gnu/include ./arm-sysroot
     LINKER=$(pwd)/arm-sysroot/lib/ld-linux-aarch64.so.1
-    wget http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_14.1.0-2_amd64.deb
-    dpkg-deb -x gcc-aarch64-linux-gnu_14.1.0-2_amd64.deb ./arm-sysroot
+    wget http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_14.2.0-1_amd64.deb
+    dpkg-deb -x gcc-aarch64-linux-gnu_14.2.0-1_amd64.deb ./arm-sysroot
     export LD_LIBRARY_PATH=$(pwd)/arm-sysroot/lib:$LD_LIBRARY_PATH
     sudo ln -s $LINKER /lib/ld-linux-aarch64.so.1
     SYSROOT="$(pwd)/arm-sysroot"

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 61 additions & 3 deletions
@@ -399,19 +399,77 @@ def format_of(ty):
     return src
 
 
+def serialize_kernel_metadata(arg, args_dict):
+    args_dict["num_warps"] = arg.num_warps
+    args_dict["threads_per_warp"] = arg.threads_per_warp
+    args_dict["shared_memory"] = arg.shared
+    args_dict["kernel_name"] = arg.name
+    args_dict["spv_name"] = f"{arg.name}.spv"
+
+
+def serialize_args(args, constants, signature):
+    import numbers
+    dir_path = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS")
+    if not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+    print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}")
+
+    cnt = 0
+    args_dict = {"gridX": args[cnt], "gridY": args[cnt + 1], "gridZ": args[cnt + 2]}
+    args_dict["argument_list"] = []
+    counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0}
+    cnt = 4
+    for arg in args[cnt:]:
+        if type(arg).__name__ == "KernelMetadata":
+            serialize_kernel_metadata(arg, args_dict)
+
+        if isinstance(arg, torch.Tensor):
+            cpu_tensor = arg.cpu()
+            tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt")
+            with open(tensor_path, "wb") as f:
+                torch.save(cpu_tensor, f)
+            new_arg = {
+                "name": f"tensor_{counts['tensors']}", "type": "tensor", "dtype": str(arg.dtype), "ctype":
+                signature[counts["karg_cnt"]]
+            }
+            args_dict["argument_list"].append(new_arg)
+            counts["karg_cnt"] += 1
+            counts["tensors"] += 1
+
+        if isinstance(arg, numbers.Number):
+            if counts["karg_cnt"] not in constants:
+                new_arg = {
+                    "name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": args[cnt], "ctype":
+                    signature[counts["karg_cnt"]]
+                }
+                args_dict["argument_list"].append(new_arg)
+                counts["karg_cnt"] += 1
+                counts["scalars"] += 1
+        cnt += 1
+    # Dump argument info as a JSON file
+    json_path = os.path.join(dir_path, "args_data.json")
+    with open(json_path, "w", encoding="utf-8") as json_file:
+        import json
+        json.dump(args_dict, json_file, indent=4)
+
+
 class XPULauncher:
 
     def __init__(self, src, metadata):  # pylint: disable=unused-argument
         ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else {}
         cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
-        constants = {cst_key(key): value for key, value in constants.items()}
-        signature = {cst_key(key): value for key, value in src.signature.items()}
-        src = make_launcher(constants, signature, ids)
+        self.constants = {cst_key(key): value for key, value in constants.items()}
+        self.signature = {cst_key(key): value for key, value in src.signature.items()}
+        src = make_launcher(self.constants, self.signature, ids)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch
 
     def __call__(self, *args, **kwargs):
+        # Serialize KernelArguments for SPIR-V Runner
+        serialize_kernel_args = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS", None)
+        if serialize_kernel_args:
+            serialize_args(args, self.constants, self.signature)
         self.launch(*args, **kwargs)
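For context, a minimal sketch of reading this dump back, assuming TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS points at the dump directory and using only the file names the serializer above writes (args_data.json, tensor_<n>.pt):

import json
import os

import torch

# Same environment variable the launcher checks above; it names the
# directory serialize_args wrote into.
dump_dir = os.environ["TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS"]

# args_data.json holds grid sizes, kernel metadata, and the argument list.
with open(os.path.join(dump_dir, "args_data.json"), encoding="utf-8") as f:
    args_data = json.load(f)
print(args_data["gridX"], args_data["gridY"], args_data["gridZ"])

# Tensor arguments were saved with torch.save as tensor_<n>.pt.
for arg in args_data["argument_list"]:
    if arg["type"] == "tensor":
        tensor = torch.load(os.path.join(dump_dir, f"{arg['name']}.pt"))
        print(arg["name"], arg["ctype"], tuple(tensor.shape))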

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 1 deletion
@@ -88,7 +88,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUAccelerateMatmul();
   mlir::registerTritonAMDGPUOptimizeEpilogue();
   mlir::registerTritonAMDGPUReorderInstructions();
-  mlir::registerTritonAMDGPUStreamPipeline();
   mlir::registerTritonAMDGPUStreamPipelineV2();
   mlir::registerTritonAMDGPUCanonicalizePointers();
 #endif

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-61f8a7f618901797ee8663389a29722f29216a96
+b5cc222d7429fe6f18c787f633d5262fac2e676f

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 7 deletions
@@ -460,17 +460,12 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
     If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason.
     The compiler is still free to change it for better performance.
   }];
-  let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr<UnitAttr>:$efficient_layout);
+  let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
   let results = (outs TT_Tensor:$result);
-  let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
+  let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)";
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  let builders = [
-      OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder),
-      [{
-      build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr());
-      }]>];
 }
 
 def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
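For illustration only (operand and types are hypothetical, not from this commit): switching the flags from BoolAttr/OptionalAttr to UnitAttr moves them out of the attribute dictionary and into the optional keywords defined by the new assemblyFormat.

// Old syntax: allow_reorder was a BoolAttr carried in attr-dict.
%0 = tt.reshape %src {allow_reorder = true} : tensor<32x32xf32> -> tensor<1024xf32>
// New syntax: unit attributes print as optional keywords.
%0 = tt.reshape %src allow_reorder efficient_layout : tensor<32x32xf32> -> tensor<1024xf32>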

lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 2 deletions
@@ -1084,8 +1084,9 @@ LogicalResult AxisInfoAnalysis::visitOperation(
 
 void AxisInfoAnalysis::visitForOpInductionVar(
     scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
-  auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue();
-  auto step = getLatticeElementFor(op, op.getStep())->getValue();
+  ProgramPoint programPoint(op);
+  auto lb = getLatticeElementFor(&programPoint, op.getLowerBound())->getValue();
+  auto step = getLatticeElementFor(&programPoint, op.getStep())->getValue();
 
   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);

lib/Analysis/Utility.cpp

Lines changed: 4 additions & 3 deletions
@@ -904,15 +904,16 @@ class ConstantAnalysis : public DataFlowAnalysis {
 
   LogicalResult initialize(Operation *top) override {
     WalkResult result = top->walk([&](Operation *op) {
-      if (failed(visit(op)))
+      ProgramPoint programPoint(op);
+      if (failed(visit(&programPoint)))
         return WalkResult::interrupt();
       return WalkResult::advance();
     });
     return success(!result.wasInterrupted());
   }
 
-  LogicalResult visit(ProgramPoint point) override {
-    Operation *op = point.get<Operation *>();
+  LogicalResult visit(ProgramPoint *point) override {
+    Operation *op = point->getOperation();
     Attribute value;
     if (matchPattern(op, m_Constant(&value))) {
       auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -678,7 +678,7 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op,
 }
 
 LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
-  if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
+  if (!op.getAllowReorder() || op.getEfficientLayout())
     return failure();
   return canonicalizeViewOrBroadcast(op, rewriter);
 }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 8 additions & 13 deletions
@@ -1044,16 +1044,12 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
   return res;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto distributedLayout =
-          mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
-    return distributedLayout.getWarpsPerCTA();
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
-        "supported yet");
-  }
+  auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
+  auto warps = distributedLayout.getWarpsPerCTA();
+  auto rank = warps.size();
+  auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
+  warps[kDim] = 1;
+  return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
@@ -2764,7 +2760,7 @@ struct CanonicalizeConvertFromReshape
     return failure();
   if (isExpensiveView(convert.getSrc().getType(), op.getType()))
     return failure();
-  if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
+  if (!op.getAllowReorder() || op.getEfficientLayout())
     return failure();
 
   rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
@@ -2885,8 +2881,7 @@ struct CanonicalizeConvertFromConvert
 
   // cvt(reshape) -> reshape
   if (auto reshape = dyn_cast<ReshapeOp>(arg)) {
-    if (!reshape.getAllowReorder() ||
-        reshape.getEfficientLayout().has_value() ||
+    if (!reshape.getAllowReorder() || reshape.getEfficientLayout() ||
         isExpensiveView(reshape.getSrc().getType(), op.getType()))
       return failure();
 
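The rewritten getWarpsPerCTA above now also forces the warp count along the operand's K dimension to 1. A tiny Python sketch of that logic, with a hypothetical warpsPerCTA of [4, 2]:

# Hypothetical parent layout with warpsPerCTA = [4, 2].
warps = [4, 2]
op_idx = 0  # 0 = A operand (K is the last dim); 1 = B (K is second-to-last)
rank = len(warps)
k_dim = rank - 1 if op_idx == 0 else rank - 2
warps[k_dim] = 1  # warps do not split the K dimension
print(warps)  # [4, 1]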

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
66
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
77
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
8+
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
89
#include "triton/Tools/LinearLayout.h"
910
#include "triton/Tools/StrUtil.h"
1011
#include "llvm/ADT/DenseMap.h"
@@ -822,16 +823,81 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
822823
return ret;
823824
}
824825

826+
LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
827+
DotOperandEncodingAttr dot) {
828+
// TODO,BE. Implement ampereMMA in terms of this one
829+
int rank = shape.size();
830+
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
831+
int kWidth = dot.getKWidth();
832+
bool isA = dot.getOpIdx() == 0;
833+
834+
assert(mma.isAmpere());
835+
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
836+
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
837+
838+
MLIRContext *ctx = mma.getContext();
839+
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
840+
841+
// Implement A. For B transpose in the end
842+
std::vector<std::vector<int32_t>> registers;
843+
std::vector<std::vector<int32_t>> lanes;
844+
int32_t i = 1;
845+
// kWidth contiguous elements
846+
while (i < kWidth) {
847+
registers.push_back({i, 0});
848+
i *= 2;
849+
}
850+
// 4 threads per chunk
851+
for (int j = 0; j < 2; j++) {
852+
lanes.push_back({i, 0});
853+
i *= 2;
854+
}
855+
// 8 threads going down
856+
lanes.push_back({0, 1});
857+
lanes.push_back({0, 2});
858+
lanes.push_back({0, 4});
859+
// 2 tiles in column-major order
860+
// Just one if it's the B operand
861+
if (isA) {
862+
registers.push_back({0, 8});
863+
}
864+
registers.push_back({i, 0});
865+
866+
if (!isA) {
867+
for (auto &r : registers) {
868+
std::swap(r[0], r[1]);
869+
}
870+
for (auto &l : lanes) {
871+
std::swap(l[0], l[1]);
872+
}
873+
}
874+
875+
LinearLayout ctaLayout(
876+
{{S("register"), registers}, {S("lane"), lanes}},
877+
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
878+
879+
auto order = dot.getCTAOrder();
880+
assert(order[0] == 1 && order[1] == 0);
881+
ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
882+
883+
return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
884+
}
885+
825886
std::optional<LinearLayout>
826887
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
827-
828888
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
829889
return dotOperandMfmaToLinearLayout(*this, shape);
830890
}
831891
if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
832892
return dotOperandDpasToLinearLayout(*this, shape);
833893
}
834894

895+
// TODO Activate in a follow-up PR
896+
// else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
897+
// if (mma.isAmpere()) {
898+
// return ampereDotToLinearLayout(shape, *this);
899+
// }
900+
//}
835901
return std::nullopt;
836902
}
837903
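To see what ampereDotToLinearLayout builds, the register and lane bases can be traced by hand. A small Python sketch mirroring the loop above for the A operand, with a hypothetical kWidth of 4:

# Mirrors the basis construction in ampereDotToLinearLayout for the A
# operand (is_a = True); kWidth = 4 is a hypothetical value.
kWidth = 4
is_a = True

registers, lanes = [], []
i = 1
while i < kWidth:        # kWidth contiguous elements per thread
    registers.append([i, 0])
    i *= 2
for _ in range(2):       # 4 threads per chunk
    lanes.append([i, 0])
    i *= 2
lanes += [[0, 1], [0, 2], [0, 4]]  # 8 threads going down
if is_a:                 # second tile, A operand only
    registers.append([0, 8])
registers.append([i, 0])

print(registers)  # [[1, 0], [2, 0], [0, 8], [16, 0]]
print(lanes)      # [[4, 0], [8, 0], [0, 1], [0, 2], [0, 4]]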
