
Commit 363ee9f

Merge OpenAI Triton commit 6fce184 (#5321)
This PR changes the Triton base from b5fea1e to 6fce184 (Oct 14). Pass rate: 94.11%.
2 parents 16e2a59 + f8b0466 commit 363ee9f

File tree

35 files changed: +2947 -492 lines


lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 31 additions & 0 deletions
@@ -444,6 +444,37 @@ LogicalResult Fp4ToFpOp::verifyFp4ToFp(mlir::Operation *op,
              << ", dst=" << resShape[i] << ", axis=" << axis << ")";
     }
   }
+  if (bool(resTy.getEncoding()) != bool(srcTy.getEncoding()))
+    return op->emitError()
+           << "source and result must both have an encoding, or neither";
+  if (!resTy.getEncoding()) {
+    return success();
+  }
+  auto srcLl = toLinearLayout(srcTy);
+  auto resLl = toLinearLayout(resTy);
+  auto *ctx = srcTy.getContext();
+  auto regDim = StringAttr::get(ctx, "register");
+  auto outDims = standardOutDimNames(ctx, rank);
+
+  // We use backward inference here as it is strictly more general.
+  Attribute inferSrc;
+  auto dialect =
+      resTy.getEncoding()
+          .getDialect()
+          .getRegisteredInterface<triton::DialectInferLayoutInterface>();
+  assert(dialect);
+  if (failed(dialect->inferFp4ToFpOpEncoding(
+          resTy.getShape(), axis, resTy.getEncoding(), inferSrc,
+          /*fwdInference*/ false, std::nullopt))) {
+    return op->emitError() << "failed to infer encoding";
+  }
+  if (!areLayoutsEquivalent(srcTy.getShape(),
+                            cast<LayoutEncodingTrait>(inferSrc),
+                            cast<LayoutEncodingTrait>(srcTy.getEncoding())))
+    return op->emitError()
+           << "Src and Dst encodings are not compatible:\n"
+           << toLinearLayout(srcTy.getShape(), inferSrc).toString() << "\n"
+           << srcLl.toString();
   return success();
 }

lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp

Lines changed: 32 additions & 0 deletions
@@ -40,6 +40,36 @@ static bool willIncreaseRegisterPressure(Operation *op) {
   return false;
 }
 
+// Return true if it has side effects that are either unknown or writes.
+static bool hasWriteSideEffect(Operation *op) {
+  auto effects = getEffectsRecursively(op);
+  if (!effects)
+    return false;
+  return llvm::any_of(*effects, [](MemoryEffects::EffectInstance effect) {
+    return !isa<MemoryEffects::Read, MemoryEffects::Allocate,
+                MemoryEffects::Free>(effect.getEffect());
+  });
+}
+
+// Return true if there is a write side effect on any path between start and
+// end ops. This assumes start dominates end.
+static bool crossWriteSideEffectingOp(Operation *start, Operation *end) {
+  auto ancestor = start->getBlock()->findAncestorOpInBlock(*end);
+  // Couldn't find an ancestor in the same block, conservatively assume true.
+  if (!ancestor)
+    return true;
+  Operation *nextOp = start->getNextNode();
+  while (nextOp) {
+    if (hasWriteSideEffect(nextOp))
+      return true;
+    if (nextOp == ancestor)
+      return false;
+    nextOp = nextOp->getNextNode();
+  }
+  assert(false && "op doesn't dominate other");
+  return true;
+}
+
 class TritonGPUReorderInstructionsPass
     : public impl::TritonGPUReorderInstructionsBase<
           TritonGPUReorderInstructionsPass> {
@@ -135,6 +165,8 @@ class TritonGPUReorderInstructionsPass
       // after the conversion to OpIdx=0.
       if (!dom.dominates(op.getOperation(), AOp.getOperation()))
         return;
+      if (crossWriteSideEffectingOp(op, AOp))
+        return;
       moveAfter(op, AOp);
     });
     return;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 1 deletion
@@ -431,7 +431,7 @@ static Attribute inferDstEncoding(triton::gpu::Fp4ToFpOp op, Attribute srcEnc) {
 
 static Attribute inferSrcEncoding(triton::gpu::Fp4ToFpOp op, Attribute dstEnc) {
   Attribute srcEnc;
-  auto shape = op.getSrc().getType().getShape();
+  auto shape = op.getType().getShape();
   if (succeeded(
           dstEnc.getDialect()
               .getRegisteredInterface<triton::DialectInferLayoutInterface>()

python/src/llvm.cc

Lines changed: 6 additions & 5 deletions
@@ -47,8 +47,8 @@ std::unique_ptr<TargetMachine>
 createTargetMachine(llvm::Module *module, std::string proc,
                     bool enable_fp_fusion, const std::string &features) {
   std::string error;
-  auto target = llvm::TargetRegistry::lookupTarget(
-      module->getTargetTriple().str(), error);
+  auto target =
+      llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
   llvm::TargetOptions opt;
   bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT");
   if (enable_fp_fusion)
@@ -278,15 +278,16 @@ void init_triton_llvm(py::module &&m) {
          const std::string proc,
          const std::string features) {
         std::string error;
-        auto target = llvm::TargetRegistry::lookupTarget(triple, error);
+        llvm::Triple targetTriple(triple);
+        auto target = llvm::TargetRegistry::lookupTarget(targetTriple, error);
         if (!target) {
           throw std::runtime_error("target lookup error: " + error);
         }
         llvm::TargetOptions opt;
         // Target machine is only used to create the data layout.
         std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
-            llvm::Triple(triple), proc, features, opt, llvm::Reloc::PIC_,
-            std::nullopt, llvm::CodeGenOptLevel::None)};
+            targetTriple, proc, features, opt, llvm::Reloc::PIC_, std::nullopt,
+            llvm::CodeGenOptLevel::None)};
         // set data layout
         mod->setDataLayout(machine->createDataLayout());
       });

python/test/gluon/test_core.py

Lines changed: 1 addition & 1 deletion
@@ -1346,7 +1346,7 @@ def fp8e8m0_to_float32(scale):
     return scale
 
 
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell", run=False)
 def test_tcgen05_mma_scaled_minimal():
     M = 128
     N = 128
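
The marker change above swaps skipif for xfail(..., run=False): off Blackwell the test is still never executed, but it is reported as an expected failure ("x") rather than a skip ("s"), and it can still be forced to run with pytest's --runxfail. A minimal sketch of the two markers side by side, with a hypothetical has_hw() guard standing in for is_blackwell():

# Sketch of the marker semantics; has_hw() is a hypothetical stand-in for the
# repository's is_blackwell() check.
import pytest

def has_hw() -> bool:
    return False

# Reported as skipped ("s") when the condition holds; the body never runs.
@pytest.mark.skipif(not has_hw(), reason="Requires Blackwell")
def test_skipif_variant():
    assert True

# Also not executed when the condition holds (run=False), but reported as an
# expected failure ("x"); `pytest --runxfail` ignores the mark and runs it.
@pytest.mark.xfail(not has_hw(), reason="Requires Blackwell", run=False)
def test_xfail_variant():
    assert True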

python/test/gluon/test_frontend.py

Lines changed: 49 additions & 0 deletions
@@ -2445,6 +2445,29 @@ def kernel():
 """)
 
 
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_mfma_scaled_none(target):
+
+    @gluon.jit
+    def kernel():
+        mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(4, [16, 16, 128], True, [1, 1])
+        scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout([],
+                                                                    [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]],
+                                                                    [], [], [16, 4])
+
+        a = ttgl.full([16, 64], 0x11, ttgl.uint8, ttgl.DotOperandLayout(0, mfma_layout, 16))
+        b = ttgl.full([64, 16], 0x22, ttgl.uint8, ttgl.DotOperandLayout(1, mfma_layout, 16))
+
+        b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout)
+        acc = ttgl.full([16, 16], 0, ttgl.float32, mfma_layout)
+        ttgl.amd.cdna4.mfma_scaled(a, None, 'e2m1', b, b_scale, 'e2m1', acc)
+
+    with pytest.raises(CompilationError) as e:
+        run_parser(kernel, target=target)
+
+    assert "Scales must not be None" in str(e.value)
+
+
 @pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
 def test_amd_wmma_scaled(target):
@@ -2497,6 +2520,32 @@ def kernel():
 """)
 
 
+@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
+def test_amd_wmma_scaled_none(target):
+
+    @gluon.jit
+    def kernel():
+        wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 128])
+        wmma_layout_packed: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 64])
+        scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout([[0, 1], [0, 2]],
+                                                                    [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], [], [],
+                                                                    [16, 4])
+        a_layout: ttgl.constexpr = ttgl.DotOperandLayout(0, wmma_layout_packed, 16)
+        b_layout: ttgl.constexpr = ttgl.DotOperandLayout(1, wmma_layout_packed, 16)
+
+        a = ttgl.full([16, 64], 0x11, ttgl.uint8, a_layout)
+        b = ttgl.full([64, 16], 0x22, ttgl.uint8, b_layout)
+        b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout)
+        acc = ttgl.full([16, 16], 0, ttgl.float32, wmma_layout)
+
+        ttgl.amd.gfx1250.wmma_scaled(a, None, 'e2m1', b, b_scale, 'e2m1', acc)
+
+    with pytest.raises(CompilationError) as e:
+        run_parser(kernel, target=target)
+
+    assert "Scales must not be None" in str(e.value)
+
+
 @gluon.jit
 def padded_shared_layout_kernel():
     shape: ttgl.constexpr = [64, 64]
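
Both new tests build their scale operand with a DistributedLinearLayout whose positional arguments are easiest to read when named. The sketch below re-states the CDNA4 scale layout from the first test with hypothetical variable names; the argument order (register, lane, warp, block bases, then shape) is inferred from the test code above, not confirmed by this diff. Each lane-basis vector toggles one bit of the lane index, so six bases span a 64-lane CDNA4 wavefront over the [16, 4] scale tensor.

# A minimal sketch, assuming the Gluon import path used by these tests and the
# positional argument order seen above; the variable names are illustrative only.
from triton.experimental.gluon import language as ttgl

reg_bases = []                       # no register bases: one scale element per lane
lane_bases = [[1, 0], [2, 0], [4, 0],
              [8, 0], [0, 1], [0, 2]]  # 6 bases -> 2**6 = 64 lanes (CDNA4 wavefront)
warp_bases = []                      # a single warp
block_bases = []                     # a single CTA
shape = [16, 4]                      # shape of the scale tensor being distributed

scale_layout = ttgl.DistributedLinearLayout(reg_bases, lane_bases, warp_bases,
                                            block_bases, shape)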
