
Commit aa25956

Merge branch 'main' into fast-sub-group-transpose-extend
2 parents 6f74535 + 61fd54d commit aa25956

21 files changed: +470 additions, -122 deletions


.github/actions/setup-pytorch/action.yml

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ runs:
     uses: ./.github/actions/load
     env:
       # Increase this value to reset cache
-      CACHE_NUMBER: 11
+      CACHE_NUMBER: 12
     with:
       path: pytorch
       key: pytorch-$PYTORCH_CACHE_KEY-$CACHE_NUMBER

.github/pins/pytorch-upstream.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-8321eec009c8c79145ebccd51fdfc336e5f8b848
+487873f7cafeb0fd390eaefe40496b804bceabbd

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 61 additions & 3 deletions

@@ -399,19 +399,77 @@ def format_of(ty):
     return src


+def serialize_kernel_metadata(arg, args_dict):
+    args_dict["num_warps"] = arg.num_warps
+    args_dict["threads_per_warp"] = arg.threads_per_warp
+    args_dict["shared_memory"] = arg.shared
+    args_dict["kernel_name"] = arg.name
+    args_dict["spv_name"] = f"{arg.name}.spv"
+
+
+def serialize_args(args, constants, signature):
+    import numbers
+    dir_path = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS")
+    if not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+    print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}")
+
+    cnt = 0
+    args_dict = {"gridX": args[cnt], "gridY": args[cnt + 1], "gridZ": args[cnt + 2]}
+    args_dict["argument_list"] = []
+    counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0}
+    cnt = 4
+    for arg in args[cnt:]:
+        if type(arg).__name__ == "KernelMetadata":
+            serialize_kernel_metadata(arg, args_dict)
+
+        if isinstance(arg, torch.Tensor):
+            cpu_tensor = arg.cpu()
+            tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt")
+            with open(tensor_path, "wb") as f:
+                torch.save(cpu_tensor, f)
+            new_arg = {
+                "name": f"tensor_{counts['tensors']}", "type": "tensor", "dtype": str(arg.dtype), "ctype":
+                signature[counts["karg_cnt"]]
+            }
+            args_dict["argument_list"].append(new_arg)
+            counts["karg_cnt"] += 1
+            counts["tensors"] += 1
+
+        if isinstance(arg, numbers.Number):
+            if counts["karg_cnt"] not in constants:
+                new_arg = {
+                    "name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": args[cnt], "ctype":
+                    signature[counts["karg_cnt"]]
+                }
+                args_dict["argument_list"].append(new_arg)
+            counts["karg_cnt"] += 1
+            counts["scalars"] += 1
+        cnt += 1
+    # Dump argument info as a JSON file
+    json_path = os.path.join(dir_path, "args_data.json")
+    with open(json_path, "w", encoding="utf-8") as json_file:
+        import json
+        json.dump(args_dict, json_file, indent=4)
+
+
 class XPULauncher:

     def __init__(self, src, metadata):  # pylint: disable=unused-argument
         ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else {}
         cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
-        constants = {cst_key(key): value for key, value in constants.items()}
-        signature = {cst_key(key): value for key, value in src.signature.items()}
-        src = make_launcher(constants, signature, ids)
+        self.constants = {cst_key(key): value for key, value in constants.items()}
+        self.signature = {cst_key(key): value for key, value in src.signature.items()}
+        src = make_launcher(self.constants, self.signature, ids)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch

     def __call__(self, *args, **kwargs):
+        # Serialize KernelArguments for SPIR-V Runner
+        serialize_kernel_args = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS", None)
+        if serialize_kernel_args:
+            serialize_args(args, self.constants, self.signature)
         self.launch(*args, **kwargs)
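
The serialization path above only runs when TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS names a dump directory, since XPULauncher.__call__ checks the variable at launch time. Below is a minimal sketch of driving it from a script; the kernel, shapes, and dump path are hypothetical, while the environment variable and the dumped file names (tensor_<i>.pt, args_data.json) come from the diff.

import os

# Assumption: the dump directory must be set before any kernel launch.
os.environ["TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS"] = "/tmp/spirv_runner_dump"

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.randn(1024, device="xpu")
y = torch.randn(1024, device="xpu")
out = torch.empty_like(x)
add_kernel[(4, )](x, y, out, x.numel(), BLOCK=256)

# After the launch, /tmp/spirv_runner_dump should contain tensor_<i>.pt files
# and args_data.json describing the grid sizes, kernel metadata and argument list.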

lib/Analysis/AxisInfo.cpp

Lines changed: 1 addition & 1 deletion

@@ -895,7 +895,7 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
       lhsDivisibility = 1;
     }
     auto numBits = log2Int(lhsDivisibility);
-    return multiplyDivisor(lhsDivisibility, 1 << shift);
+    return multiplyDivisor(lhsDivisibility, 1ll << shift);
   }

   int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs,
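
A note on the fix above: the divisor is a 64-bit quantity, so the shifted value can exceed 32 bits. With a plain 1 the shift is performed on a 32-bit int (undefined behavior in C++ once the shift amount reaches the type width), whereas 1ll keeps the whole computation in 64 bits. A rough Python illustration of the width problem only (not a model of the C++ undefined behavior):

shift = 40
full = 1 << shift              # Python ints are arbitrary precision: 1099511627776
truncated = full & 0xFFFFFFFF  # what a 32-bit result could hold: 0
print(full, truncated)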

lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp

Lines changed: 5 additions & 5 deletions

@@ -76,7 +76,7 @@ static void warpScan(SmallVector<SmallVector<Value>> &srcValues,
        acc[j] = select(mask, tempAcc[j], acc[j]);
      }
    }
-    srcValues[srcIndex] = acc;
+    srcValues[srcIndex] = std::move(acc);
  }
}

@@ -128,8 +128,8 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
                             ConversionPatternRewriter &rewriter,
                             const TargetInfoBase &targetInfo,
                             ScanLoweringHelper &helper,
-                             SmallVector<Value> smemBases,
-                             SmallVector<Type> smemTypes, Value warpId,
+                             ArrayRef<Value> smemBases,
+                             ArrayRef<Type> smemTypes, Value warpId,
                             Value laneIdAxis, Value parallelLaneId) {
  Location loc = helper.getLoc();
  unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();

@@ -224,7 +224,7 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
                         srcValues[srcIndex - i * elementStride][j]);
      }
    }
-    srcValues[srcIndex - i * elementStride] = laneValue;
+    srcValues[srcIndex - i * elementStride] = std::move(laneValue);
  }
  // For the next chunk start back from the value containing the
  // accumulated value of all the warps.

@@ -303,7 +303,7 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
                    srcValues[srcIndex - i * elementStride][j], laneValue[j]);
      }
    }
-    srcValues[srcIndex - i * elementStride] = laneValue;
+    srcValues[srcIndex - i * elementStride] = std::move(laneValue);
  }
  // For the next chunk start back from the value containing the
  // accumulated value of all the warps.

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 10 additions & 10 deletions

@@ -1044,16 +1044,16 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
   return res;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto distributedLayout =
-          mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
-    return distributedLayout.getWarpsPerCTA();
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
-        "supported yet");
-  }
+  auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
+  auto warps = distributedLayout.getWarpsPerCTA();
+  // FIXME: This is a temporary solution to avoid distribute-to-warps.mlir
+  // failure.
+  if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH"))
+    return warps;
+  auto rank = warps.size();
+  auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
+  warps[kDim] = 1;
+  return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 67 additions & 1 deletion

@@ -5,6 +5,7 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/DenseMap.h"

@@ -822,16 +823,81 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return ret;
 }

+LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
+                                     DotOperandEncodingAttr dot) {
+  // TODO,BE. Implement ampereMMA in terms of this one
+  int rank = shape.size();
+  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
+  int kWidth = dot.getKWidth();
+  bool isA = dot.getOpIdx() == 0;
+
+  assert(mma.isAmpere());
+  assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
+         (rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
+
+  MLIRContext *ctx = mma.getContext();
+  SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+
+  // Implement A. For B transpose in the end
+  std::vector<std::vector<int32_t>> registers;
+  std::vector<std::vector<int32_t>> lanes;
+  int32_t i = 1;
+  // kWidth contiguous elements
+  while (i < kWidth) {
+    registers.push_back({i, 0});
+    i *= 2;
+  }
+  // 4 threads per chunk
+  for (int j = 0; j < 2; j++) {
+    lanes.push_back({i, 0});
+    i *= 2;
+  }
+  // 8 threads going down
+  lanes.push_back({0, 1});
+  lanes.push_back({0, 2});
+  lanes.push_back({0, 4});
+  // 2 tiles in column-major order
+  // Just one if it's the B operand
+  if (isA) {
+    registers.push_back({0, 8});
+  }
+  registers.push_back({i, 0});
+
+  if (!isA) {
+    for (auto &r : registers) {
+      std::swap(r[0], r[1]);
+    }
+    for (auto &l : lanes) {
+      std::swap(l[0], l[1]);
+    }
+  }
+
+  LinearLayout ctaLayout(
+      {{S("register"), registers}, {S("lane"), lanes}},
+      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
+
+  auto order = dot.getCTAOrder();
+  assert(order[0] == 1 && order[1] == 0);
+  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
+}
+
 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
-
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
   }
   if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
     return dotOperandDpasToLinearLayout(*this, shape);
   }

+  // TODO Activate in a follow-up PR
+  // else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+  //  if (mma.isAmpere()) {
+  //    return ampereDotToLinearLayout(shape, *this);
+  //  }
+  //}
   return std::nullopt;
 }
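
To make the basis construction in ampereDotToLinearLayout concrete, the small Python mirror below enumerates the register and lane bases for a given kWidth. It is only an illustration of the loops in the new function, not a Triton API.

def ampere_dot_bases(k_width, is_a):
    # Mirrors the register/lane construction in ampereDotToLinearLayout.
    registers, lanes = [], []
    i = 1
    # kWidth contiguous elements go to registers
    while i < k_width:
        registers.append([i, 0])
        i *= 2
    # 4 threads per chunk
    for _ in range(2):
        lanes.append([i, 0])
        i *= 2
    # 8 threads going down
    lanes += [[0, 1], [0, 2], [0, 4]]
    # 2 tiles in column-major order; just one for the B operand
    if is_a:
        registers.append([0, 8])
    registers.append([i, 0])
    # For B, transpose every basis vector
    if not is_a:
        registers = [[r[1], r[0]] for r in registers]
        lanes = [[l[1], l[0]] for l in lanes]
    return registers, lanes


# kWidth = 2, operand A:
#   registers = [[1, 0], [0, 8], [8, 0]]
#   lanes     = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]]
print(ampere_dot_bases(2, True))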

python/test/unit/runtime/test_bindings.py

Lines changed: 3 additions & 3 deletions

@@ -59,6 +59,8 @@ def walk_fn(op):
         torch.empty((32, 32), device=device),  # out_ptr
         16,  # BLOCK_SIZE
     ]
+    target = triton.runtime.driver.active.get_current_target()
+    backend = triton.compiler.compiler.make_backend(target)
     src = triton.compiler.compiler.ASTSource(
         fn=kernel,
         signature={

@@ -69,12 +71,10 @@ def walk_fn(op):
         constants={kernel.arg_names[i]: arg
                    for i, arg in enumerate(args)
                    if not isinstance(arg, torch.Tensor)},
-        attrs=kernel._get_config(*args, ),
+        attrs=backend.get_attrs_descriptor(args, kernel.params),
     )

     context = triton._C.libtriton.ir.context()
-    target = triton.runtime.driver.active.get_current_target()
-    backend = triton.compiler.compiler.make_backend(target)
     options = backend.parse_options(dict())
     codegen_fns = dict()
     module_map = backend.get_module_map()
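
The test now creates the backend before building the ASTSource, because the attribute descriptor is derived by the backend rather than by the kernel's private _get_config. A self-contained sketch of that ordering with a hypothetical kernel and argument list; only make_backend, get_attrs_descriptor, and kernel.params are taken from the diff.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    tl.store(out_ptr + offsets, tl.load(in_ptr + offsets))


# Hypothetical arguments matching the kernel's parameters.
args = [torch.empty(16, device="xpu"), torch.empty(16, device="xpu"), 16]

# Backend first: it now owns the attribute-descriptor logic.
target = triton.runtime.driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
attrs = backend.get_attrs_descriptor(args, copy_kernel.params)
print(attrs)  # per-argument specialization hints derived by the backend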

python/test/unit/runtime/test_subproc.py

Lines changed: 4 additions & 3 deletions

@@ -3,6 +3,7 @@

 import triton
 import triton.language as tl
+from triton.backends.compiler import AttrsDescriptor
 from triton.compiler import ASTSource

 target = triton.runtime.driver.active.get_current_target()

@@ -25,7 +26,7 @@ def kernel_sub(a, b, o, N: tl.constexpr):


 def test_compile_in_subproc() -> None:
-    config = triton.compiler.AttrsDescriptor(tuple(range(4)), ())
+    config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
     multiprocessing.set_start_method('fork')
     proc = multiprocessing.Process(target=compile_fn, args=(config, ))
     proc.start()

@@ -47,7 +48,7 @@ def kernel_dot(Z):


 def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
-    config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
+    config = AttrsDescriptor.from_hints({0: 16})
     assert multiprocessing.get_start_method() == 'fork'
     proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
     proc.start()

@@ -86,7 +87,7 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
     gc.disable()

     # stage 1.p
-    config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
+    config = AttrsDescriptor.from_hints({0: 16})
     compile_empty_kernel_with_gc(config)

     # stage 2.p
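
AttrsDescriptor.from_hints replaces the old positional constructor in these tests; it takes a mapping from parameter index to a hint value. Reading 16 as a divide-by-16 hint is an interpretation on my part, not something the diff states, so treat the sketch below accordingly.

from triton.backends.compiler import AttrsDescriptor

# New style, as used in the updated tests: one entry per parameter index.
config_one = AttrsDescriptor.from_hints({0: 16})
config_four = AttrsDescriptor.from_hints({i: 16 for i in range(4)})

# These replace the old positional form, e.g.
#     triton.compiler.AttrsDescriptor(tuple(range(4)), ())
# which listed the affected parameter indices directly.
print(config_one, config_four)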
