3 changes: 1 addition & 2 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -153,8 +153,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan
warmup_time = n_warmup * estimate_ms
rep_time = n_repeat * estimate_ms

times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all",
device_type=device)
times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all")
times = torch.tensor(times, dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)

1 change: 0 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
@@ -35,5 +35,4 @@ add_triton_library(TritonGPUToLLVM
TritonGPUTransforms
TritonIntelGPUTransforms
TritonNvidiaGPUTransforms
NVGPUIR
)
44 changes: 34 additions & 10 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -41,36 +41,60 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
// Register layout conversion:
//
// [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
// [2, 3], [6, 7] [2], [3], [6], [7]
//
// Original access order:
//
// [0, 1], [2, 3], [4, 5], [6, 7]
//
// Transformed access order:
//
// [0], [2], [1], [3], [4], [6], [5], [7]
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 7]);
}
return ret;
}
if (inBitWidth == 8 && ouBitWidth == 16) {
// Register layout conversion:
//
// [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
// [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15]
//
// Original access order:
//
// [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
//
// Transformed access order:
//
// [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i + 0]);
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 10]);
ret.push_back(values[i + 11]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 12]);
ret.push_back(values[i + 13]);
ret.push_back(values[i + 10]);
ret.push_back(values[i + 11]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
}
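For reference, here is a minimal Python sketch (not part of the diff) of the access-order transformations that the new comments in reorderValues describe; the index sequences are taken directly from those comments, while the function names and the small self-checks are illustrative.

# Model of the transformed access orders from the comments above.
# 16-bit -> 32-bit: each tile of 8 values becomes [0, 2, 1, 3, 4, 6, 5, 7].
# 8-bit  -> 16-bit: each tile of 16 values becomes
#                   [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15].

def reorder_16_to_32(values):
    out = []
    for i in range(0, len(values), 8):
        for off in (0, 2, 1, 3, 4, 6, 5, 7):
            out.append(values[i + off])
    return out

def reorder_8_to_16(values):
    out = []
    for i in range(0, len(values), 16):
        for off in (0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15):
            out.append(values[i + off])
    return out

assert reorder_16_to_32(list(range(8))) == [0, 2, 1, 3, 4, 6, 5, 7]
assert reorder_8_to_16(list(range(16))) == [
    0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]
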
5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -238,6 +238,11 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
}

SmallVector<unsigned> getWarpOrder(Attribute layout) {
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
return getWarpOrder(dotLayout.getParent());
}
}
auto order = getOrder(layout);
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mmaLayout.isHopper()) {
87 changes: 54 additions & 33 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -473,10 +473,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
int nIndex = 1 + hasBatchDim;
(void)mIndex, (void)nIndex;

assert(((shape[mIndex] == 1 || shape[mIndex] >= getMDim()) &&
(shape[nIndex] == 1 || shape[nIndex] >= getNDim())) &&
"Unsupported tensor shape for given mfma layout");

assert(((getMDim() == 32 && getNDim() == 32) ||
(getMDim() == 16 && getNDim() == 16)) &&
"Unsupported mfma type");
@@ -580,55 +576,76 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
// 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is
// held by exactly one thread, maintaining the same number of global loads
// as in a blocked layout.
//
// Another use of the linear layout is to support rare corner cases, for
// example when one instruction tile is larger than the tensor.
auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());

if (dotMfmaLayout.getOpIdx() == 0) {
return std::nullopt;
}
auto rank = shape.size();
bool hasBatchDim = rank == 3;
int mIndex = 0 + hasBatchDim;

auto kWidth = dotMfmaLayout.getKWidth();
int32_t kWidth = dotMfmaLayout.getKWidth();
auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
int32_t kSize = shape[kDim];
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

if (kWidth != 8 || warpsPerCTA[mIndex] != 1) {
return std::nullopt;
}

MLIRContext *ctx = dotMfmaLayout.getContext();
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);

StringAttr kRegister = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");

// register order
// operand A: [1, 0] / [2, 1, 0]
// operand B: [0, 1] / [1, 2, 0]
// for both cases it is [k, nonk]/[k, nonk, batch]
SmallVector<unsigned> order = triton::gpu::getOrder(dotMfmaLayout);
auto tileLayout = LinearLayout::empty();
// warp order
// common for both operand A and B: [0, 1] / [0, 1, 2]
// in both cases it is [M dim, N dim]/[batch, M dim, N dim]
SmallVector<unsigned> warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout);

// A lane holds kWidth consecutive elements along the k dimension, so the
// base register vectors for one tile are initialized in the following way:
// {1, 0}, {2, 0} ... {kWidth/2, 0}
std::vector<std::vector<int32_t>> registerBase;
for (int32_t elem = 1; elem < kWidth; elem *= 2)
registerBase.emplace_back(std::vector<int32_t>{elem, 0});

std::vector<std::vector<int32_t>> laneBase;
int32_t kTileSize = -1;

if (mfmaLayout.getMDim() == 32) {
// Based on canonical MFMA linear layout, which handles 4 consecutive
// elements along the register dimension, kWidth=8 means we have 8
// consecutive elements, so we have an additional {4, 0} base vector here.
// For lane dim, since the MFMA thread arrangement is {K, N} = {2, 32}, this
// means that mapping of first 5 base (up to thread 16) vectors will be an
// identity along N dim. Thread 32 will be mapped to element 8 in K
// dimension, because kWidth == 8.
tileLayout = LinearLayout(
{{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
// The canonical MFMA linear layout handles 4 consecutive elements along
// the register dimension, while the dot operand handles a variable number
// (kWidth) of consecutive elements. For the lane dim, since the MFMA thread
// arrangement is {K, N} = {2, 32}, the mapping of the first 5 base vectors
// (up to thread 16) is an identity along the N dim. Thread 32 is mapped to
// element kWidth in the K dimension.
laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}};
kTileSize = kWidth * 2;
} else {
assert(mfmaLayout.getMDim() == 16);
// For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this
// means that mapping of first 4 base (up to thread 16) vectors will be an
// identity along N dim. Thread 16 will be mapped to element 8 in K
// dimension, because kWidth == 8. Thread 32 is mapped to element 16 as that
// is 2*kWidth in K dim.
tileLayout = LinearLayout(
{{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}},
{outDimNames[order[0]], outDimNames[order[1]]});
// identity along N dim. Thread 16 will be mapped to element kWidth in K
// dimension. Thread 32 is mapped to element 2*kWidth in K dim.
laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}};
kTileSize = kWidth * 4;
}
assert(kTileSize != -1);
// Add repeats of registers along K dimension to register base vectors
for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
registerBase.emplace_back(std::vector<int32_t>{elem, 0});

// The base vectors above are defined in a fixed order [non-k-dim, k-dim].
// To assign them to the actual matrix dimensions, the `order` array is used:
// For operand A: non-k-dim -> dim0, k-dim -> dim1
// For operand B: non-k-dim -> dim1, k-dim -> dim0
LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
{outDimNames[order[0]], outDimNames[order[1]]});

if (hasBatchDim) {
assert(order[2] == 0);
@@ -639,8 +656,10 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
}

LinearLayout warpLayout =
identityND(S("warp"), warpsPerCTA, order, outDimNames);
LinearLayout ctaLayout = tileLayout * warpLayout;
identityND(kWarp, warpsPerCTA, warpOrder, outDimNames);

LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
warpLayout.transposeOuts(outDimNames);

return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
}
@@ -1001,6 +1020,8 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
if (!mmaLayout || !mmaLayout.isHopper())
return false;
if (isa<PointerType>(tensorTy.getElementType()))
return false;
if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16)
return false;
if (order[0] != 1)
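To make the generalization above easier to follow, here is a small Python sketch (not part of the diff) of how the register and lane base vectors are built for the B-operand MFMA linear layout; the base-vector values mirror the C++ above, while the helper name and return format are illustrative.

def mfma_dot_operand_bases(k_width, k_size, m_dim):
    # A lane holds k_width consecutive elements along K, so one tile starts
    # with register bases {1, 0}, {2, 0}, ..., {k_width/2, 0}.
    register_base = []
    elem = 1
    while elem < k_width:
        register_base.append([elem, 0])
        elem *= 2

    if m_dim == 32:
        # Thread arrangement {K, N} = {2, 32}: threads 1..16 are an identity
        # along N, thread 32 jumps to element k_width along K.
        lane_base = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [k_width, 0]]
        k_tile_size = k_width * 2
    else:
        assert m_dim == 16
        # Thread arrangement {K, N} = {4, 16}: threads 16 and 32 jump to
        # elements k_width and 2 * k_width along K.
        lane_base = [[0, 1], [0, 2], [0, 4], [0, 8],
                     [k_width, 0], [k_width * 2, 0]]
        k_tile_size = k_width * 4

    # Repeat the instruction tile along K until the tensor's K extent is covered.
    elem = k_tile_size
    while elem < k_size:
        register_base.append([elem, 0])
        elem *= 2
    return register_base, lane_base

# Example: k_width=8, k_size=64, m_dim=32 yields register bases
# [[1, 0], [2, 0], [4, 0], [16, 0], [32, 0]].
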
8 changes: 2 additions & 6 deletions python/triton/testing.py
@@ -139,7 +139,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)


def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device_type="xpu"):
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
@@ -164,11 +164,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
fn()
di.synchronize()

# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2 cache
# doesn't contain any input data before the run
cache_size = 256 * 1024 * 1024
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device_type)
cache = runtime.driver.active.get_empty_cache_for_benchmark()

# Estimate the runtime of the function
start_event = Event(enable_timing=True)
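A hedged usage sketch (not part of the diff): after this change do_bench no longer takes a device_type argument, and the L2-flushing buffer is obtained from the active driver, so a caller simply benchmarks a callable. The tensor shape, device, and return mode below are illustrative.

import torch
from triton.testing import do_bench

x = torch.randn(4096, 4096, device="cuda")  # device depends on the active backend
median_ms = do_bench(lambda: x @ x, warmup=25, rep=100, return_mode="median")
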
34 changes: 31 additions & 3 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -1034,7 +1034,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
// CHECK-LABEL: atomic_add_f32
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
@@ -1048,7 +1048,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

// -----

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
// CHECK-LABEL: atomic_add_f32_scalar
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
// CHECK: llvm.icmp "eq"
@@ -1062,7 +1062,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
// CHECK-LABEL: atomic_add_f32
tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
@@ -1076,6 +1076,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f16_nomask
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>) -> tensor<256xf16, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f16_withmask
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked>
tt.return
}
}

// -----

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32
38 changes: 38 additions & 0 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -241,3 +241,41 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f32_nomask
// CHECK: atom.global.gpu.acq_rel.add.v4.f32
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f32_withmask
// CHECK: atom.global.gpu.acq_rel.add.v2.f32
// CHECK: atom.global.gpu.acq_rel.add.v2.f32
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f16_withmask
// CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16
// CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked>
tt.return
}
}
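At the source level, the new lowering tests above correspond to masked tl.atomic_add on contiguous f16/f32 data. A hedged Triton sketch of such a kernel follows; the kernel name, block size, and launch details are illustrative, and whether the packed/vectorized atomics are actually emitted depends on the pointer alignment and mask constancy encoded in the test hints.

import triton
import triton.language as tl

@triton.jit
def atomic_add_kernel(dst_ptr, src_ptr, n_elements, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    vals = tl.load(src_ptr + offsets, mask=mask)
    # Masked atomic add; on sufficiently aligned, contiguous f16/f32 data this
    # can be lowered to the packed f16x2 / vectorized f32 atomics checked above.
    tl.atomic_add(dst_ptr + offsets, vals, mask=mask)
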
7 changes: 7 additions & 0 deletions third_party/amd/backend/driver.py
@@ -503,3 +503,10 @@ def get_current_target(self):
def get_benchmarker(self):
from triton.testing import do_bench
return do_bench

def get_empty_cache_for_benchmark(self):
import torch

# It's the same as the Nvidia backend.
cache_size = 256 * 1024 * 1024
return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')