Commit 184f6db

Merge OpenAI Triton commit d57cbee (#4453)

This PR changes the Triton base from 4dfdc32 to d57cbee (Jun 6). Pass rate: 97.23% -> 97.2%.

2 parents 9d7bc59 + dee2d72

58 files changed: +4613 −3996 lines

.github/workflows/integration-tests-nvidia.yml

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ jobs:
         if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
           source /venv/bin/activate
         fi
-        make test-unit
+        make NUM_PROCS=24 test-unit
     - name: Run interpreter tests
       if: ${{ matrix.runner[0] == 'nvidia-h100' }}
       run: make test-interpret

Makefile

Lines changed: 7 additions & 6 deletions
@@ -7,6 +7,7 @@ BUILD_DIR := $(shell cd python; $(PYTHON) -c 'from build_helpers import get_cmak
 TRITON_OPT := $(BUILD_DIR)/bin/triton-opt
 PYTEST := $(PYTHON) -m pytest
 LLVM_BUILD_PATH ?= ".llvm-project/build"
+NUM_PROCS ?= 8

 # Incremental builds

@@ -30,25 +31,25 @@ test-cpp:

 .PHONY: test-unit
 test-unit: all
-	cd python/test/unit && $(PYTEST) -s -n 8 --ignore=language/test_line_info.py \
+	cd python/test/unit && $(PYTEST) -s -n $(NUM_PROCS) --ignore=language/test_line_info.py \
 		--ignore=language/test_subprocess.py --ignore=test_debug.py
-	$(PYTEST) -s -n 8 python/test/unit/language/test_subprocess.py
-	$(PYTEST) -s -n 8 python/test/unit/test_debug.py --forked
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/unit/language/test_subprocess.py
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/unit/test_debug.py --forked
 	$(PYTEST) -s -n 8 python/triton_kernels/tests/
 	TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
 	# Run attention separately to avoid out of gpu memory
 	$(PYTEST) -vs python/tutorials/06-fused-attention.py
 	TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
 		$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
-	$(PYTEST) -s -n 8 python/test/gluon
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/gluon

 .PHONY: test-gluon
 test-gluon: all
-	$(PYTEST) -s -n 8 python/test/gluon
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/gluon

 .PHONY: test-regression
 test-regression: all
-	$(PYTEST) -s -n 8 python/test/regression
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/regression

 .PHONY: test-interpret
 test-interpret: all
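
Test parallelism is now a Makefile variable (NUM_PROCS, defaulting to the previous hard-coded 8), so runners can scale it to their core count; the GB200 workflow above, for example, invokes make NUM_PROCS=24 test-unit. Note that the python/triton_kernels suite keeps its hard-coded -n 8.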

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 7 additions & 0 deletions
@@ -263,6 +263,13 @@ void replaceUsesWithLocalLoad(
     OpBuilder &builder, OpResult old,
     TypedValue<triton::gpu::MemDescType> alloc,
     TypedValue<triton::gpu::AsyncTokenType> token = {});
+
+// Return true if the value comes from a load or a block argument.
+// This will skip convert layouts and memdesc views.
+// This is a helper useful to know if value is likely to come from shared memory
+// after converting loads into async loads.
+bool comesFromLoadOrBlockArg(Value v);
+
 } // namespace mlir::triton

 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
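
This declares, as a shared utility, the predicate that AccelerateMatmul.cpp previously kept as a local lambda (removed below); the implementation lands in Utility.cpp further down.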

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 2 additions & 2 deletions
@@ -368,8 +368,8 @@ OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
 LogicalResult MakeRangeOp::verify() {
   int64_t start = getStartAttr().getInt();
   int64_t end = getEndAttr().getInt();
-  if (start > end) {
-    return this->emitOpError() << "start must be less than or equal to end";
+  if (start >= end) {
+    return this->emitOpError() << "start must be less than end";
   }
   auto ty = getType();
   if (ty.getShape().size() != 1) {
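
The verifier previously accepted start == end, which yields a zero-element tensor; it now requires start to be strictly smaller. As a minimal sketch of the user-visible effect (assuming the frontend forwards its bounds to tt.make_range unchanged; empty_range_kernel is a hypothetical name, not part of this commit):

import triton
import triton.language as tl


@triton.jit
def empty_range_kernel():
    # start == end: formerly allowed (only start > end was rejected)
    # and produced an empty tensor; the tightened verifier now reports
    # "start must be less than end".
    idx = tl.arange(4, 4)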

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 0 additions & 23 deletions
@@ -322,29 +322,6 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
     if (!(versionMajor >= 1 && versionMajor <= 3))
       return failure();

-    // If both of the operands are not loads, we fallback to MMAv2
-    // otherwise the reg-smem roundtrip will tank the MMAv3 performance
-    auto comesFromLoadOrBlockArg = [](Value v) -> bool {
-      // Peel out the original cvt dot_op<..., #blocked>
-      // and any other potential cvt/trans ops
-      while (true) {
-        if (auto cvtOp = v.getDefiningOp<ConvertLayoutOp>()) {
-          v = cvtOp.getSrc();
-          continue;
-        }
-        if (auto transOp = v.getDefiningOp<TransOp>()) {
-          v = transOp.getSrc();
-          continue;
-        }
-        break;
-      }
-      // We also accept block arguments as they appear in many MLIR tests
-      // If this is problematic we can totally drop them
-      return isa<BlockArgument>(v) ||
-             (v.getDefiningOp() &&
-              isa<LoadOp, DescriptorLoadOp>(v.getDefiningOp()));
-    };
-
     bool aFromLoad = comesFromLoadOrBlockArg(dotOp.getA());
     bool bFromLoad = comesFromLoadOrBlockArg(dotOp.getB());
     auto origDotOp = dotOp;
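
The MMAv3 heuristic itself is unchanged: if neither operand comes from a load, the pattern still falls back to MMAv2 to avoid the reg-smem roundtrip; only the lambda has moved to the shared comesFromLoadOrBlockArg utility declared above.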

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 25 additions & 0 deletions
@@ -1554,4 +1554,29 @@ void replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
     alloc.erase();
   }
 }
+
+bool comesFromLoadOrBlockArg(Value v) {
+  // Peel out the original cvt dot_op<..., #blocked>
+  // and any other potential cvt/trans ops
+  while (true) {
+    Operation *def = v.getDefiningOp();
+    if (!def)
+      break;
+    if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(def)) {
+      v = cvtOp.getSrc();
+      continue;
+    }
+    if (def->hasTrait<OpTrait::MemDescViewTrait>()) {
+      v = def->getOperand(0);
+      continue;
+    }
+    break;
+  }
+  // We also accept block arguments as they appear in many MLIR tests
+  // If this is problematic we can totally drop them
+  return isa<BlockArgument>(v) ||
+         (v.getDefiningOp() &&
+          isa<LoadOp, DescriptorLoadOp, DescriptorGatherOp>(v.getDefiningOp()));
+}
+
 } // namespace mlir::triton
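
Compared with the removed lambda, the shared implementation peels ops carrying MemDescViewTrait rather than TransOp specifically, and additionally counts DescriptorGatherOp as a load source.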

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp

Lines changed: 13 additions & 6 deletions
@@ -226,11 +226,17 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
     return std::nullopt;

   // Propagate defs of exp.
-  for (auto expOp : loop.getOps<math::Exp2Op>()) {
-    auto tensorTy = dyn_cast<RankedTensorType>(expOp.getType());
-    if (tensorTy && tensorTy.getNumElements() > 256) {
-      schedule.trySchedule(defaultPartition, expOp);
-      scheduleDependencies(loop, schedule, defaultPartition, expOp);
+  for (Operation &op : loop.getOps()) {
+    if (!isa<math::Exp2Op, ElementwiseInlineAsmOp>(op))
+      continue;
+    int elementCount = 0;
+    for (Type type : op.getResultTypes()) {
+      if (auto tensorTy = dyn_cast<RankedTensorType>(type))
+        elementCount += tensorTy.getNumElements();
+    }
+    if (elementCount > 256) {
+      schedule.trySchedule(defaultPartition, &op);
+      scheduleDependencies(loop, schedule, defaultPartition, &op);
     }
   }

@@ -242,7 +248,8 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
   while (userPartitions.size() < mmas.size()) {
     userPartitions.push_back(schedule.addPartition(userPartitions.size()));
   }
-  for (auto [mmaOp, userPartition] : llvm::zip(mmas, userPartitions)) {
+  for (auto [mmaOp, userPartition] :
+       llvm::reverse(llvm::zip(mmas, userPartitions))) {
     scheduleUsers(loop, schedule, userPartition, mmaOp);
   }
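
Two behavioral changes here: the >256-element propagation now applies to ElementwiseInlineAsmOp as well as math.exp2, summing elements across all tensor results instead of inspecting a single result type, and MMA user partitions are now assigned in reverse order.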

lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ template <class MMAOpTy> class LHSToTMem : public OpRewritePattern<MMAOpTy> {
       isDistributedLayoutTMemCompatible(tcGen5MMAOp, srcType, lhsMemDescType);
   Attribute newLayout = srcLayout;
   if (!layoutTmemCompatible) {
-    if (triton::tools::getBoolEnv("ALLOW_LHS_TMEM_LAYOUT_CONVERSION")) {
+    if (!comesFromLoadOrBlockArg(src) ||
+        triton::tools::getBoolEnv("ALLOW_LHS_TMEM_LAYOUT_CONVERSION")) {
      newLayout = getLHSTMemLayout(tcGen5MMAOp, srcType);
    } else {
      return failure();
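
With the shared helper available, the LHS is now promoted to a TMem-compatible layout whenever it does not come from a load or block argument; the ALLOW_LHS_TMEM_LAYOUT_CONVERSION environment variable is still honored to force the conversion for load-fed operands.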

python/src/gluon_ir.cc

Lines changed: 10 additions & 1 deletion
@@ -11,6 +11,7 @@

 using namespace mlir;
 namespace py = pybind11;
+namespace tt = triton;
 namespace ttg = triton::gpu;
 namespace ttng = triton::nvidia_gpu;

@@ -298,7 +299,15 @@ void init_gluon_ir(py::module &&m) {
             self.create<ttng::AsyncTMAScatterOp>(descPtr, xOffsets, yOffset,
                                                  src);
           })
-
+      .def("create_broadcast",
+           [](TritonOpBuilder &self, Value &arg, Type retTy) -> Value {
+             return self.create<tt::BroadcastOp>(retTy, arg);
+           })
+      .def(
+          "create_expand_dims",
+          [](TritonOpBuilder &self, Value &arg, int axis, Type retTy) -> Value {
+            return self.create<tt::ExpandDimsOp>(retTy, arg, axis);
+          })
       .def("create_warp_return",
           [](GluonOpBuilder &self) -> Operation * {
             return self.create<ttg::WarpReturnOp>();
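
These two bindings back gluon's expand_dims ([None, :]-style indexing) and implicit broadcasting, exercised by the new broadcast_kernel test below; the tt alias gives the builder access to the core Triton dialect ops.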

python/test/gluon/test_frontend.py

Lines changed: 181 additions & 0 deletions
@@ -1,4 +1,5 @@
 import expecttest
+from triton.runtime.jit import MockTensor
 import torch
 import pytest
 import re

@@ -600,3 +601,183 @@ def kernel():
 }
 }
 """)
+
+
+@gluon.jit
+def broadcast_kernel():
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0])
+    a = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, layout))[None, :]
+    b = ttgl.arange(0, 16, layout=ttgl.SliceLayout(1, layout))[:, None]
+    0 + a + b
+
+
+def test_broadcast(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = broadcast_kernel.warmup(sanitize_overflow=False, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["source"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @broadcast_kernel() attributes {noinline = false} {
+    %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc)
+    %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc)
+    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> loc(#loc)
+    %c0_i32 = arith.constant 0 : i32 loc(#loc)
+    %c0_i32_0 = arith.constant 0 : i32 loc(#loc)
+    %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> loc(#loc)
+    %4 = arith.addi %cst, %1 : tensor<1x16xi32, #blocked> loc(#loc)
+    %5 = tt.broadcast %4 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked> loc(#loc)
+    %6 = tt.broadcast %3 : tensor<16x1xi32, #blocked> -> tensor<16x16xi32, #blocked> loc(#loc)
+    %7 = arith.addi %5, %6 : tensor<16x16xi32, #blocked> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")
+
+
+@gluon.jit
+def math_kernel():
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
+    a = ttgl.full([16, 16], 1, ttgl.float32, layout)
+    b = ttgl.full([16, 16], 2, ttgl.float32, layout)
+    c = ttgl.full([16, 16], 4, ttgl.float32, layout)
+    d = ttgl.full([16, 16], 1, ttgl.int32, layout)
+    e = ttgl.full([16, 16], 1, ttgl.int32, layout)
+    ttgl.umulhi(d, e)
+    ttgl.exp(a)
+    ttgl.exp2(a)
+    ttgl.log(a)
+    ttgl.log2(a)
+    ttgl.cos(a)
+    ttgl.sin(a)
+    ttgl.sqrt(a)
+    ttgl.sqrt_rn(a)
+    ttgl.rsqrt(a)
+    ttgl.abs(a)
+    ttgl.fdiv(a, b)
+    ttgl.div_rn(a, b)
+    ttgl.erf(a)
+    ttgl.floor(a)
+    ttgl.ceil(a)
+    ttgl.fma(a, b, c)
+
+
+def test_math(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = math_kernel.warmup(sanitize_overflow=False, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["source"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @math_kernel() attributes {noinline = false} {
+    %cst = arith.constant 1.000000e+00 : f32 loc(#loc)
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %cst_1 = arith.constant 2.000000e+00 : f32 loc(#loc)
+    %cst_2 = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %cst_3 = arith.constant 4.000000e+00 : f32 loc(#loc)
+    %cst_4 = arith.constant dense<4.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %c1_i32 = arith.constant 1 : i32 loc(#loc)
+    %cst_5 = arith.constant dense<1> : tensor<16x16xi32, #blocked> loc(#loc)
+    %c1_i32_6 = arith.constant 1 : i32 loc(#loc)
+    %cst_7 = arith.constant dense<1> : tensor<16x16xi32, #blocked> loc(#loc)
+    %0 = tt.mulhiui %cst_5, %cst_7 : tensor<16x16xi32, #blocked> loc(#loc)
+    %1 = math.exp %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %2 = math.exp2 %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %3 = math.log %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %4 = math.log2 %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %5 = math.cos %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %6 = math.sin %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %7 = math.sqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %8 = tt.precise_sqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %9 = math.rsqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %10 = math.absf %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %11 = arith.divf %cst_0, %cst_2 : tensor<16x16xf32, #blocked> loc(#loc)
+    %12 = tt.precise_divf %cst_0, %cst_2 : tensor<16x16xf32, #blocked> loc(#loc)
+    %13 = math.erf %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %14 = math.floor %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %15 = math.ceil %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %16 = math.fma %cst_0, %cst_2, %cst_4 : tensor<16x16xf32, #blocked> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")
+
+
+@gluon.jit
+def pair_add(a0, a1, b0, b1):
+    return a0 + b0, a1 + b1
+
+
+@gluon.jit
+def reduce_kernel(out):
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
+    a = ttgl.full([16, 16], 1, ttgl.float32, layout)
+    b = ttgl.full([16, 16], 2, ttgl.float32, layout)
+    s0 = ttgl.sum(a, 0)
+    ttgl.static_assert(s0.type.layout == ttgl.SliceLayout(0, layout))
+    s1 = ttgl.sum(a, 1)
+    ttgl.static_assert(s1.type.layout == ttgl.SliceLayout(1, layout))
+
+    scalar = ttgl.max(s0, 0)
+    ttgl.static_assert(scalar.type == ttgl.float32)
+
+    s1 = ttgl.convert_layout(s1, s0.type.layout)
+
+    pairs = ttgl.reduce((a, b), 0, pair_add)
+    ttgl.static_assert(pairs[0].type.layout == ttgl.SliceLayout(0, layout))
+    ttgl.static_assert(pairs[1].type.layout == ttgl.SliceLayout(0, layout))
+    result = scalar + s1 + pairs[0] + pairs[1]
+    tl.store(out + ttgl.arange(0, 16, s0.type.layout), result)
+
+
+def test_reduce(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = reduce_kernel.warmup(MockTensor(ttgl.float32), sanitize_overflow=False, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["ttgir"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#loc = loc(unknown)
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @reduce_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} {
+    %cst = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %0 = "tt.reduce"(%cst_0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
+      %12 = arith.addf %arg1, %arg2 : f32 loc(#loc)
+      tt.reduce.return %12 : f32 loc(#loc)
+    }) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %1 = "tt.reduce"(%cst_0) <{axis = 1 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
+      %12 = arith.addf %arg1, %arg2 : f32 loc(#loc)
+      tt.reduce.return %12 : f32 loc(#loc)
+    }) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc)
+    %2 = "tt.reduce"(%0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
+      %12 = arith.maxnumf %arg1, %arg2 : f32 loc(#loc)
+      tt.reduce.return %12 : f32 loc(#loc)
+    }) : (tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) -> f32 loc(#loc)
+    %3 = ttg.convert_layout %1 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %4:2 = "tt.reduce"(%cst_0, %cst) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown), %arg3: f32 loc(unknown), %arg4: f32 loc(unknown)):
+      %12 = arith.addf %arg1, %arg3 : f32 loc(#loc)
+      %13 = arith.addf %arg2, %arg4 : f32 loc(#loc)
+      tt.reduce.return %12, %13 : f32, f32 loc(#loc)
+    }) : (tensor<16x16xf32, #blocked>, tensor<16x16xf32, #blocked>) -> (tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) loc(#loc)
+    %5 = tt.splat %2 : f32 -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %6 = arith.addf %5, %3 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %7 = arith.addf %6, %4#0 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %8 = arith.addf %7, %4#1 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %10 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %11 = tt.addptr %10, %9 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    tt.store %11, %8 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+""")
