
Commit cd276bf

Merge branch 'main' into lesh/conda-oct
2 parents 8d42079 + 0ba3707 commit cd276bf

24 files changed: +1009, -238 lines

.github/workflows/triton-benchmarks.yml

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,5 @@
 name: Triton benchmarks
+run-name: ${{ inputs.run_name }}
 
 on:
   workflow_dispatch:
@@ -19,6 +20,10 @@ on:
           - ELAPSED_TIME
           - UPSTREAM_PYTORCH_PROFILER
         default: PYTORCH_LEGACY_PROFILER_USING_IPEX
+      run_name:
+        description: Run name
+        type: string
+        default: "Triton benchmarks"
   schedule:
     - cron: "5 23 * * *"
   pull_request:
@@ -248,7 +253,6 @@ jobs:
         run: |
           cd benchmarks/triton_kernels_benchmark
           TRITON_INTEL_ADVANCED_PATH=1 \
-          TRITON_INTEL_ENABLE_INSTR_SCHED=1 \
           TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 \
           IGC_VISAOptions=" -enableBCR" \
           python flash_attention_fwd_benchmark.py --reports $REPORTS
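
The new run_name input only takes effect on manual dispatches. As a rough, hypothetical illustration (not part of this commit), triggering the workflow with that input through the GitHub REST API could look like the sketch below; the repository slug and token handling are assumptions.

import os
import requests

# Hypothetical dispatch of the benchmark workflow with the new run_name input.
# The repository slug below is an assumption, not taken from this commit.
REPO = "intel/intel-xpu-backend-for-triton"
WORKFLOW_FILE = "triton-benchmarks.yml"

resp = requests.post(
    f"https://api.github.com/repos/{REPO}/actions/workflows/{WORKFLOW_FILE}/dispatches",
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
    },
    json={"ref": "main", "inputs": {"run_name": "Conda nightly benchmarks"}},
    timeout=30,
)
resp.raise_for_status()  # GitHub answers 204 No Content on success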

benchmarks/setup.py

Lines changed: 34 additions & 8 deletions
@@ -125,11 +125,37 @@ def run(self):
         super().run()
 
 
-setup(name="triton-kernels-benchmark", packages=[
-    "triton_kernels_benchmark",
-], package_dir={
-    "triton_kernels_benchmark": "triton_kernels_benchmark",
-}, package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]}, cmdclass={
-    "build_ext": build_ext,
-    "clean": clean,
-}, ext_modules=[CMakeExtension("triton_kernels_benchmark")])
+def get_git_commit_hash(length=8):
+    try:
+        cmd = ["git", "rev-parse", f"--short={length}", "HEAD"]
+        return f"+git{subprocess.check_output(cmd).strip().decode('utf-8')}"
+    except (
+        FileNotFoundError,
+        subprocess.CalledProcessError,
+        subprocess.TimeoutExpired,
+    ):
+        return ""
+
+
+setup(
+    name="triton-kernels-benchmark",
+    version="3.1.0" + get_git_commit_hash(),
+    packages=["triton_kernels_benchmark"],
+    install_requires=[
+        "torch",
+        "pandas",
+        "tabulate",
+        "matplotlib",
+    ],
+    package_dir={"triton_kernels_benchmark": "triton_kernels_benchmark"},
+    package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]},
+    cmdclass={
+        "build_ext": build_ext,
+        "clean": clean,
+    },
+    ext_modules=[CMakeExtension("triton_kernels_benchmark")],
+    extra_require={
+        "ipex": ["numpy<=2.0", "intel-extension-for-pytorch=2.1.10"],
+        "pytorch": ["torch>=2.6"],
+    },
+)
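
For context, the new version helper appends a PEP 440 local-version suffix derived from the current checkout. Run standalone (outside setup.py) the committed helper behaves as sketched below, assuming git is on PATH:

import subprocess

# The committed helper, exercised on its own to show the resulting version string.
def get_git_commit_hash(length=8):
    try:
        cmd = ["git", "rev-parse", f"--short={length}", "HEAD"]
        return f"+git{subprocess.check_output(cmd).strip().decode('utf-8')}"
    except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
        return ""

print("3.1.0" + get_git_commit_hash())
# inside a git checkout: e.g. "3.1.0+git<8-char hash>"; elsewhere: "3.1.0"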

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 1 addition & 2 deletions
@@ -153,8 +153,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan
     warmup_time = n_warmup * estimate_ms
     rep_time = n_repeat * estimate_ms
 
-    times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all",
-                            device_type=device)
+    times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all")
     times = torch.tensor(times, dtype=torch.float)
     return _summarize_statistics(times, quantiles, return_mode)
 

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -35,5 +35,4 @@ add_triton_library(TritonGPUToLLVM
   TritonGPUTransforms
   TritonIntelGPUTransforms
   TritonNvidiaGPUTransforms
-  NVGPUIR
 )

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 34 additions & 10 deletions
@@ -41,36 +41,60 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
   if (inBitWidth == ouBitWidth)
     return values;
   if (inBitWidth == 16 && ouBitWidth == 32) {
+    // Register layout conversion:
+    //
+    //   [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
+    //   [2, 3], [6, 7]    [2], [3], [6], [7]
+    //
+    // Original access order:
+    //
+    //   [0, 1], [2, 3], [4, 5], [6, 7]
+    //
+    // Transformed access order:
+    //
+    //   [0], [2], [1], [3], [4], [6], [5], [7]
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 8) {
       ret.push_back(values[i]);
-      ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 4]);
-      ret.push_back(values[i + 5]);
       ret.push_back(values[i + 2]);
+      ret.push_back(values[i + 1]);
       ret.push_back(values[i + 3]);
+      ret.push_back(values[i + 4]);
       ret.push_back(values[i + 6]);
+      ret.push_back(values[i + 5]);
       ret.push_back(values[i + 7]);
     }
     return ret;
   }
   if (inBitWidth == 8 && ouBitWidth == 16) {
+    // Register layout conversion:
+    //
+    //   [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
+    //   [4, 5, 6, 7], [12, 13, 14, 15]  [4, 5], [6, 7], [12, 13], [14, 15]
+    //
+    // Original access order:
+    //
+    //   [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
+    //
+    // Transformed access order:
+    //
+    //   [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
     SmallVector<Value> ret;
     for (unsigned i = 0; i < values.size(); i += 16) {
-      ret.push_back(values[i + 0]);
+      ret.push_back(values[i]);
       ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 2]);
-      ret.push_back(values[i + 3]);
-      ret.push_back(values[i + 8]);
-      ret.push_back(values[i + 9]);
-      ret.push_back(values[i + 10]);
-      ret.push_back(values[i + 11]);
       ret.push_back(values[i + 4]);
       ret.push_back(values[i + 5]);
+      ret.push_back(values[i + 2]);
+      ret.push_back(values[i + 3]);
       ret.push_back(values[i + 6]);
       ret.push_back(values[i + 7]);
+      ret.push_back(values[i + 8]);
+      ret.push_back(values[i + 9]);
       ret.push_back(values[i + 12]);
       ret.push_back(values[i + 13]);
+      ret.push_back(values[i + 10]);
+      ret.push_back(values[i + 11]);
       ret.push_back(values[i + 14]);
       ret.push_back(values[i + 15]);
     }
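
The effect of the two reorderings is easiest to see as plain index permutations. The sketch below is a standalone model of the loops above (Python lists instead of MLIR Values) and reproduces the "transformed access order" documented in the comments:

# Standalone sketch of the reorderings documented in the comments above.
# It only mirrors the index arithmetic; the real code shuffles MLIR Values.

def reorder_16_to_32(values):
    # process groups of 8 registers
    out = []
    for i in range(0, len(values), 8):
        for off in (0, 2, 1, 3, 4, 6, 5, 7):
            out.append(values[i + off])
    return out

def reorder_8_to_16(values):
    # process groups of 16 registers
    out = []
    for i in range(0, len(values), 16):
        for off in (0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15):
            out.append(values[i + off])
    return out

print(reorder_16_to_32(list(range(8))))
# [0, 2, 1, 3, 4, 6, 5, 7]
print(reorder_8_to_16(list(range(16))))
# [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]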

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 5 additions & 0 deletions
@@ -238,6 +238,11 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
 }
 
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
+  if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
+    if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
+      return getWarpOrder(dotLayout.getParent());
+    }
+  }
   auto order = getOrder(layout);
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 54 additions & 33 deletions
@@ -473,10 +473,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   int nIndex = 1 + hasBatchDim;
   (void)mIndex, (void)nIndex;
 
-  assert(((shape[mIndex] == 1 || shape[mIndex] >= getMDim()) &&
-          (shape[nIndex] == 1 || shape[nIndex] >= getNDim())) &&
-         "Unsupported tensor shape for given mfma layout");
-
   assert(((getMDim() == 32 && getNDim() == 32) ||
           (getMDim() == 16 && getNDim() == 16)) &&
          "Unsupported mfma type");
@@ -580,55 +576,76 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is
   //    held by exactly one thread, maintaining the same number of global loads
   //    as in a blocked layout.
+  //
+  // Other use of Linear layout is a support of rare corner cases,
+  // for example one instruction tile is larger than tensor
   auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
 
-  if (dotMfmaLayout.getOpIdx() == 0) {
-    return std::nullopt;
-  }
   auto rank = shape.size();
   bool hasBatchDim = rank == 3;
   int mIndex = 0 + hasBatchDim;
 
-  auto kWidth = dotMfmaLayout.getKWidth();
+  int32_t kWidth = dotMfmaLayout.getKWidth();
+  auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  int32_t kSize = shape[kDim];
   auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
 
-  if (kWidth != 8 || warpsPerCTA[mIndex] != 1) {
-    return std::nullopt;
-  }
-
   MLIRContext *ctx = dotMfmaLayout.getContext();
   SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
 
   StringAttr kRegister = S("register");
   StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
 
+  // register order
+  // operand A: [1, 0] / [2, 1, 0]
+  // operand B: [0, 1] / [1, 2, 0]
+  // for both cases it is [k, nonk]/[k, nonk, batch]
   SmallVector<unsigned> order = triton::gpu::getOrder(dotMfmaLayout);
-  auto tileLayout = LinearLayout::empty();
+  // warp order
+  // common for both operand A and B: [0, 1] / [0, 1, 2]
+  // in both cases it is [M dim, N dim]/[batch, M dim, N dim]
+  SmallVector<unsigned> warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout);
+
+  // Lane holds kWidth consecutive elements along k dimension, so
+  // base register vectors for one tile are initialized in following way:
+  // {1, 0}, {2, 0} ... {kWidth/2, 0}
+  std::vector<std::vector<int32_t>> registerBase;
+  for (int32_t elem = 1; elem < kWidth; elem *= 2)
+    registerBase.emplace_back(std::vector<int32_t>{elem, 0});
+
+  std::vector<std::vector<int32_t>> laneBase;
+  int32_t kTileSize = -1;
 
   if (mfmaLayout.getMDim() == 32) {
-    // Based on canonical MFMA linear layout, which handles 4 consecutive
-    // elements along the register dimension, kWidth=8 means we have 8
-    // consecutive elements, so we have an additional {4, 0} base vector here.
-    // For lane dim, since the MFMA thread arrangement is {K, N} = {2, 32}, this
-    // means that mapping of first 5 base (up to thread 16) vectors will be an
-    // identity along N dim. Thread 32 will be mapped to element 8 in K
-    // dimension, because kWidth == 8.
-    tileLayout = LinearLayout(
-        {{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-         {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}},
-        {outDimNames[order[0]], outDimNames[order[1]]});
+    // Canonical MFMA linear layout handles 4 consecutive elements along
+    // the register dimension. Dot operand handles varaible kWidth consecutive
+    // elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2,
+    // 32}, this means that mapping of first 5 base (up to thread 16) vectors
+    // will be an identity along N dim. Thread 32 will be mapped to element
+    // kWidth in K dimension.
    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}};
+    kTileSize = kWidth * 2;
   } else {
     assert(mfmaLayout.getMDim() == 16);
     // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this
     // means that mapping of first 4 base (up to thread 16) vectors will be an
-    // identity along N dim. Thread 16 will be mapped to element 8 in K
-    // dimension, because kWidth == 8. Thread 32 is mapped to element 16 as that
-    // is 2*kWidth in K dim.
-    tileLayout = LinearLayout(
-        {{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-         {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}},
-        {outDimNames[order[0]], outDimNames[order[1]]});
+    // identity along N dim. Thread 16 will be mapped to element kWisth in K
+    // dimension. Thread 32 is mapped to element 2*kWidth in K dim.
+    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}};
+    kTileSize = kWidth * 4;
   }
+  assert(kTileSize != -1);
+  // Add repeats of registers along K dimension to register base vectors
+  for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
+    registerBase.emplace_back(std::vector<int32_t>{elem, 0});
+
+  // Base vectors above are defined in a fixed order [non-k-dim, k-dim].
+  // To assign them to actual matrix dimensions `order` array is used.
+  // For operand A: non-k-dim -> dim0, k-dim -> dim1
+  // For operand B: non-k-dim -> dim1, k-dim -> dim0
+  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
+                          {outDimNames[order[0]], outDimNames[order[1]]});
 
   if (hasBatchDim) {
     assert(order[2] == 0);
@@ -639,8 +656,10 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   }
 
   LinearLayout warpLayout =
-      identityND(S("warp"), warpsPerCTA, order, outDimNames);
-  LinearLayout ctaLayout = tileLayout * warpLayout;
+      identityND(kWarp, warpsPerCTA, warpOrder, outDimNames);
+
+  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
+                           warpLayout.transposeOuts(outDimNames);
 
   return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
 }
@@ -1001,6 +1020,8 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
       mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
   if (!mmaLayout || !mmaLayout.isHopper())
     return false;
+  if (isa<PointerType>(tensorTy.getElementType()))
+    return false;
   if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16)
     return false;
   if (order[0] != 1)

python/triton/testing.py

Lines changed: 2 additions & 6 deletions
@@ -139,7 +139,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
 
 
-def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device_type="xpu"):
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -164,11 +164,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
     fn()
     di.synchronize()
 
-    # We maintain a buffer of 256 MB that we clear
-    # before each kernel call to make sure that the L2 cache
-    # doesn't contain any input data before the run
-    cache_size = 256 * 1024 * 1024
-    cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device_type)
+    cache = runtime.driver.active.get_empty_cache_for_benchmark()
 
     # Estimate the runtime of the function
     start_event = Event(enable_timing=True)
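
The inline 256 MB buffer moves behind the active driver, so do_bench no longer hard-codes a device. A hedged sketch of what such a driver hook can look like, mirroring the deleted lines (the class name and device string are assumptions; each backend's real driver may differ):

import torch

# Rough sketch of a per-backend driver hook replacing the removed inline code.
class ExampleDriver:
    def get_empty_cache_for_benchmark(self):
        # 256 MB buffer that do_bench zeroes before each run so the L2 cache
        # holds no input data from a previous iteration
        cache_size = 256 * 1024 * 1024
        return torch.empty(cache_size // 4, dtype=torch.int, device="xpu")  # device is backend-specific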

scripts/test-triton.sh

Lines changed: 0 additions & 1 deletion
@@ -290,7 +290,6 @@ run_benchmark_attention() {
   echo "Advanced path:"
   TRITON_INTEL_ADVANCED_PATH=1 \
   TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 \
-  TRITON_INTEL_ENABLE_INSTR_SCHED=1 \
   IGC_VISAOptions=" -enableBCR" \
   python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
 }

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 31 additions & 3 deletions
@@ -1034,7 +1034,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // -----
 
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
   // CHECK-LABEL: atomic_add_f32
   tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
     // CHECK: llvm.inline_asm
@@ -1048,7 +1048,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 // -----
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
   // CHECK-LABEL: atomic_add_f32_scalar
   tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
     // CHECK: llvm.icmp "eq"
@@ -1062,7 +1062,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // -----
 
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
   // CHECK-LABEL: atomic_add_f32
   tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
     // CHECK: llvm.inline_asm
@@ -1076,6 +1076,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 // -----
 
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} {
+    // CHECK-LABEL: atomic_add_f16_nomask
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
+    %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>) -> tensor<256xf16, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+  tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} {
+    // CHECK-LABEL: atomic_add_f16_withmask
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16
+    // CHECK: atom.global.gpu.acq_rel.add.noftz.f16
+    %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: store_f32
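
For orientation, the new lit tests above cover the PTX lowering of f16 atomic adds. A hypothetical user-level Triton kernel that would exercise that path is sketched below; the kernel name, sizes, and device are made up, not taken from the tests:

import torch
import triton
import triton.language as tl

@triton.jit
def atomic_add_f16_kernel(dest_ptr, src_ptr, n_elements, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    vals = tl.load(src_ptr + offsets, mask=mask)
    # fp16 atomic adds; with a mask this can lower to scalar f16 atomics,
    # without one it can use the packed f16x2 form checked above
    tl.atomic_add(dest_ptr + offsets, vals, mask=mask)

dst = torch.zeros(1024, dtype=torch.float16, device="cuda")
src = torch.randn(1024, dtype=torch.float16, device="cuda")
atomic_add_f16_kernel[(4,)](dst, src, dst.numel(), BLOCK=256)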
