From 85400f80bf859a34ad7a746ffda877faf80312ab Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 19 Nov 2025 16:40:46 -0800 Subject: [PATCH 01/17] [BACKEND] run remove backward prop until a fix point (#8776) --- .../Transforms/RemoveLayoutConversions.cpp | 41 +++++++++++-------- test/TritonGPU/combine.mlir | 7 ++-- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index efe0b890dc..3531c0bf6d 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -127,7 +127,7 @@ class LayoutRematerialization { } void cleanup(); - void backwardRematerialization(); + bool backwardRematerialization(); void backwardRematerialization(ConvertLayoutOp convertOp); // TODO: Merge the three hoistConvert*(); functions as they are duplicate code void hoistConvertDotOperand(); @@ -1019,7 +1019,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice( return success(); } -void LayoutRematerialization::backwardRematerialization() { +bool LayoutRematerialization::backwardRematerialization() { + bool changed = false; // Go through each ConvertLayoutOp. SmallVector convertOps; funcOp.walk( @@ -1031,8 +1032,11 @@ void LayoutRematerialization::backwardRematerialization() { // backward slices. addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), convertOp.getResult()); + } else { + changed = true; } } + return changed; } void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { @@ -1593,12 +1597,14 @@ void LayoutRematerialization::hoistConvertIntoConditionals( rewriteSlice(slice, layout, convertOp, mapping); } -void backwardRematerialization(ModuleOp module) { - module.walk([](FuncOp funcOp) { +bool backwardRematerialization(ModuleOp module) { + bool changed = false; + module.walk([&](FuncOp funcOp) { LayoutRematerialization layoutRemat(funcOp); - layoutRemat.backwardRematerialization(); + changed |= layoutRemat.backwardRematerialization(); layoutRemat.cleanup(); }); + return changed; } void hoistConvert(ModuleOp module) { @@ -1659,17 +1665,20 @@ class TritonGPURemoveLayoutConversionsPass cleanupConvertOps(); - // 2. For remaining convert ops, try to rematerialize the slice of producer - // operation to avoid having to convert. - backwardRematerialization(m); - LLVM_DEBUG({ - DBGS() << "Module after backward remat:\n"; - m.dump(); - }); - - // Cleanup dummy converts created during backward remat. - cleanupConvertOps(); - + bool changed = false; + do { + changed = false; + // 2. For remaining convert ops, try to rematerialize the slice of + // producer operation to avoid having to convert. + changed = backwardRematerialization(m); + LLVM_DEBUG({ + DBGS() << "Module after backward remat:\n"; + m.dump(); + }); + + // Cleanup dummy converts created during backward remat. + cleanupConvertOps(); + } while (changed); // 3. For remaining converts, try to hoist them above cast generating larger // size types in order to reduce the cost of the convert op. hoistConvert(m); diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 997354685f..5421fa8d19 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2500,11 +2500,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %2 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> - // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) - // FIXME: The optimal number of conversions should be 4. - // CHECK-COUNT-5: convert_layout + // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-COUNT-4: convert_layout // CHECK-NOT: convert_layout - // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> // CHECK: } // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { From 80104b77f0d9ab0df9d3273ca68641e85ecc4117 Mon Sep 17 00:00:00 2001 From: mengfei-jiang Date: Thu, 20 Nov 2025 16:13:18 +0800 Subject: [PATCH 02/17] [AMD] Preserve Denorms for precise sqrt (#8697) This commit modifies the denorm behavior for precise sqrt: switching from FTZ (Flush To Zero) to denorm preservation. --- test/Conversion/amd/math-denorm-handling.mlir | 30 ++++---- .../ElementwiseOpToLLVM.cpp | 71 +------------------ 2 files changed, 16 insertions(+), 85 deletions(-) diff --git a/test/Conversion/amd/math-denorm-handling.mlir b/test/Conversion/amd/math-denorm-handling.mlir index c3ab9df370..a09cf197cb 100644 --- a/test/Conversion/amd/math-denorm-handling.mlir +++ b/test/Conversion/amd/math-denorm-handling.mlir @@ -64,22 +64,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @test_sqrt_rn_f32(%arg0: tensor<64xf32, #blocked>) { - // LLVM_FTZ-LABEL: test_sqrt_rn_f32 - // LLVM_FTZ: llvm.amdgcn.rsq.f32 - // LLVM_FTZ: llvm.fmul - // LLVM_FTZ: llvm.fmul - // LLVM_FTZ: llvm.fneg - // LLVM_FTZ: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.fneg - // LLVM_FTZ-NEXT: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.intr.is.fpclass - // LLVM_FTZ-NEXT: llvm.select - // - // LLVM_NO_FTZ-LABEL: test_sqrt_rn_f32 - // LLVM_NO_FTZ: llvm.intr.sqrt + // COMMON-LABEL: test_sqrt_rn_f32 + // COMMON: llvm.intr.sqrt %0 = tt.precise_sqrt %arg0 : tensor<64xf32, #blocked> tt.return } @@ -96,3 +82,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_divf_rn_f32(%arg0: tensor<64xf32, #blocked>, %arg1: tensor<64xf32, #blocked>) { + // COMMON-LABEL: test_divf_rn_f32 + // COMMON: llvm.fdiv + %0 = tt.precise_divf %arg0, %arg1 : tensor<64xf32, #blocked> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index 2dddeb898c..5d5165796c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -2264,73 +2264,6 @@ struct SqrtOpConversion } } -private: - bool ftz; -}; - -struct PreciseSqrtOpConversion - : ElementwiseOpConversionBase { - explicit PreciseSqrtOpConversion(LLVMTypeConverter &typeConverter, - ModuleAxisInfoAnalysis &axisInfoAnalysis, - bool ftz, PatternBenefit benefit) - : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit), - ftz(ftz) {} - - SmallVector createDestOps(triton::PreciseSqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - auto b = TritonLLVMOpBuilder(loc, rewriter); - // If the op is neither FP32 nor denorm flushing(ftz), it's directly lowered - // to LLVM::SqrtOp. - if (elemTy.getIntOrFloatBitWidth() != 32 || !ftz) { - return {LLVM::SqrtOp::create(rewriter, loc, elemTy, operands[0], - adaptor.getAttributes().getValue())}; - } - - // On the AMDGPU backend, instructions legalized from LLVM::SqrtOp are - // designed to always preserve denorms, according to - // https://github.com/llvm/llvm-project/blob/3d6b2d49/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L5235-L5314. - // - // For f32 inputs with ftz enabled, we need to manually lower the op to - // bypass the scaling-up-and-down process while keeping other parts - // unchanged. To ensure IEEE-compliant results, we approximate `sqrt(x)` - // using `x * rsq(x)` and apply extra refinement iterations to correct the - // result. - StringRef funcName = "llvm.amdgcn.rsq.f32"; - - Type funcType = getFunctionType(elemTy, operands[0]); - LLVM::LLVMFuncOp funcOp = - appendOrGetExternFuncOp(rewriter, op, funcName, funcType); - - Value sqrtR = - LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult(); - - Value sqrtX = operands[0][0]; - Value sqrtS = b.fmul(f32_ty, sqrtX, sqrtR); - - // Refine the approximation with Newton iteration - Value sqrtH = b.fmul(f32_ty, sqrtR, b.f32_val(0.5f)); - Value sqrtE = b.fma(b.neg(f32_ty, sqrtH), sqrtS, b.f32_val(0.5f)); - sqrtH = b.fma(sqrtH, sqrtE, sqrtH); - sqrtS = b.fma(sqrtS, sqrtE, sqrtS); - Value sqrtD = b.fma(b.neg(f32_ty, sqrtS), sqrtS, sqrtX); - sqrtS = b.fma(sqrtD, sqrtH, sqrtS); - - // Handle +0/-0/+inf - // These flags come from - // https://github.com/llvm/llvm-project/blob/217e0f39/llvm/include/llvm/ADT/FloatingPointMode.h#L239-L265. - const unsigned fcPosInf = 0x0200; - const unsigned fcNegZero = 0x0020; - const unsigned fcPosZero = 0x0040; - const unsigned fcZero = fcNegZero | fcPosZero; - - Value isZeroOrPosInf = - LLVM::IsFPClass::create(rewriter, loc, i1_ty, sqrtX, fcPosInf | fcZero); - return {b.select(isZeroOrPosInf, sqrtX, sqrtS)}; - } - private: bool ftz; }; @@ -2382,6 +2315,8 @@ void populateElementwiseOpToLLVMPatterns( typeConverter, axisInfoAnalysis, benefit); patterns.add>( typeConverter, axisInfoAnalysis, benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); @@ -2409,8 +2344,6 @@ void populateElementwiseOpToLLVMPatterns( patterns.add(typeConverter, axisInfoAnalysis, ftz, benefit); patterns.add(typeConverter, axisInfoAnalysis, ftz, benefit); - patterns.add(typeConverter, axisInfoAnalysis, ftz, - benefit); triton::populateElementwiseOpToLLVMPatterns( typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); bool hwNanPropagationSupported = targetInfo.supportMaximumMinimum(); From feedad8ad70b78616bb2fbe3813471c4e4359df0 Mon Sep 17 00:00:00 2001 From: Liyang Ling Date: Fri, 21 Nov 2025 00:23:35 +0800 Subject: [PATCH 03/17] [AMD] Fix `lowerLoops`: only erase load ops which are converted (#8737) This change addresses the issue that when there is a LoadOp and AddfOp between 2 dots in a loop, this LoadOp is not streamable in AMDGPUPipeline Pass. This case would make compile crash for erasing LoadOp which still have uses. The solution is to replace `loadToInfo` with `loadToStreamOps`, so that only erase LoadOps that are converted to Stream Ops. --- .../amd/amd-pipeline-chained-dots.mlir | 54 +++++++++++++++++++ .../lib/TritonAMDGPUTransforms/LowerLoops.cpp | 2 +- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir index 9a967924e4..01e92768d9 100644 --- a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir +++ b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir @@ -160,3 +160,57 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return %6 : tensor<128x16xf32, #mma> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [8, 1], instrShape = [16, 16, 16], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @chained_dots_with_load_bias_in_between + + // Similar to the previous test but load bias tensor bewteen 2 dots + // We expect the unstreamable load can be kept after pipelining + + // CHECK: scf.for + // CHECK: tt.dot + // CHECK: ttg.async_copy_global_to_local + // CHECK: tt.dot + // CHECK: ttg.async_wait + // CHECK: ttg.local_load + // CHECK: tt.load + // CHECK: scf.yield + + tt.func @chained_dots_with_load_bias_in_between(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg2: i64 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32) -> tensor<256x64xf32, #mma> { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %3 = tt.broadcast %1 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %4 = tt.addptr %2, %3 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %6 = tt.splat %arg3 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> + %7 = scf.for %arg5 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg6 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { + %8 = tt.load %4 : tensor<64x64x!tt.ptr, #blocked> + %9 = ttg.convert_layout %8 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %10 = tt.dot %arg1, %9, %cst : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x64xf32, #mma> + %11 = arith.muli %arg5, %c64_i32 : i32 + %12 = tt.splat %11 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %13 = arith.addi %12, %5 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %15 = tt.broadcast %14 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> + %bias_ptr = tt.addptr %6, %15 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> + %bias = tt.load %bias_ptr : tensor<256x64x!tt.ptr, #blocked> + %bias_mma = ttg.convert_layout %bias : tensor<256x64xf16, #blocked> -> tensor<256x64xf16, #mma> + %bias_f32 = arith.extf %bias_mma : tensor<256x64xf16, #mma> to tensor<256x64xf32, #mma> + %dot_bias = arith.addf %10, %bias_f32 : tensor<256x64xf32, #mma> + %21 = arith.truncf %dot_bias : tensor<256x64xf32, #mma> to tensor<256x64xf16, #mma> + %22 = ttg.convert_layout %21 : tensor<256x64xf16, #mma> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %23 = tt.dot %22, %9, %arg6 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x64xf32, #mma> + scf.yield %23 : tensor<256x64xf32, #mma> + } + tt.return %7 : tensor<256x64xf32, #mma> + } +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp index ea54600c24..2b0966c16f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp @@ -718,7 +718,7 @@ void updateSchedule(scf::ForOp &forOp, const LoadToInfoMap &loadToInfo, useAsyncCopy, axisInfoAnalysis); scheduleStreamOps(loadToStreamOps, schedule, clusters); - for (auto [l, _] : loadToInfo) { + for (auto [l, _] : loadToStreamOps) { schedule.erase(l); l->erase(); } From fe7838c57fd13becf4e4c7dc5191864967959874 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Thu, 20 Nov 2025 10:26:18 -0800 Subject: [PATCH 04/17] [consan] Handle all tmem allocations (#8787) --- lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp index 3584878121..75ef836873 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp @@ -27,7 +27,7 @@ Value createMemDescToI64(RewriterBase &rewriter, Location loc, const LLVMTypeConverter *typeConverter, ttg::MemDescType memDescTy, Value sharedMemStruct) { TritonLLVMOpBuilder b(loc, rewriter); - if (isa(memDescTy.getEncoding())) { + if (isa(memDescTy.getMemorySpace())) { return b.ptrtoint(rewriter.getIntegerType(64), sharedMemStruct); } assert(isa(memDescTy.getEncoding()) && From 2e1a036c220d007b0e356534bea8b2795ee4d153 Mon Sep 17 00:00:00 2001 From: Saeid Rostami <123997133+saeid-rostami@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:31:03 -0500 Subject: [PATCH 05/17] [AMD] Enabling Buffer Atomic for RDNA4 (#8778) This PR enables buffer atomic on RDNA4 for supported data types. --- .../amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index 7afac143a8..0c59a3d71a 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -361,6 +361,10 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW atomicRmwOp == RMWOp::FADD) { return rewriter.notifyMatchFailure(op, "RMW FADD does not support bf16"); } + if (isaFamily == ISAFamily::RDNA4 && checkType.isF64() && + atomicRmwOp == RMWOp::FADD) { + return rewriter.notifyMatchFailure(op, "RMW FADD does not support F64"); + } LDBG("RMW FADD supported 16-bit type"); auto vecSize = getVectorSize(ptr, axisAnalysisPass); @@ -624,7 +628,8 @@ struct TritonAMDGPUConvertToBufferOpsPass triton::AMD::ISAFamily isaFamily = triton::AMD::deduceISAFamily(archGenerationName); if (this->allowBufferAtomics && - (ISAFamily::CDNA3 == isaFamily || ISAFamily::CDNA4 == isaFamily)) + (ISAFamily::CDNA3 == isaFamily || ISAFamily::CDNA4 == isaFamily || + ISAFamily::RDNA4 == isaFamily)) patterns.add( context, assumptions, axisInfoAnalysis, solver, isaFamily, this->analyzeSmallTensorOfst); From ecfaec21cdf89783025dd60d11f8c5a593c8b93f Mon Sep 17 00:00:00 2001 From: Danial Javady <122740063+ZelboK@users.noreply.github.com> Date: Thu, 20 Nov 2025 15:57:31 -0500 Subject: [PATCH 06/17] [PROTON][AMD] Fix failing proton tests for AMD GPUs (#8763) Fixes upgrade to rocm7 breaking proton tests alongside implementing CircularStoreOp for gmem # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [ ] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [ ] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --------- Co-authored-by: danial javady --- .../AMDPatternProtonGPUOpToLLVM.cpp | 5 ++++- .../Profiler/RocTracer/RoctracerProfiler.cpp | 19 +++++++++++-------- .../proton/test/test_instrumentation.py | 3 --- third_party/proton/test/test_profile.py | 1 - 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp b/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp index f314e335a5..11a21f9f50 100644 --- a/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp +++ b/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp @@ -2,8 +2,10 @@ #include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h" #include "Conversion/ProtonGPUToLLVM/Utility.h" #include "Dialect/ProtonGPU/IR/Dialect.h" +#include "amd/lib/TritonAMDGPUToLLVM/Utility.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/PatternMatch.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -37,7 +39,8 @@ struct CircularStoreOpConversion // TODO(crobeck): see what buffer ops performance looks like here for // global mem (address space 1) compared to predicated ops to shared // memory - llvm::report_fatal_error("unimplemented"); + mlir::LLVM::AMD::llStore(rewriter, loc, dataPack.ptr, dataPack.record, + dataPack.isWriter); } else if (addrSpace == 3) { targetInfo.getTritonTargetInfo().storeDShared( rewriter, loc, dataPack.ptr, std::nullopt, dataPack.record, diff --git a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp index d43fd2b28d..79f55cd938 100644 --- a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -353,19 +353,22 @@ void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback( // data on stop maxCorrelationId = std::max(maxCorrelationId, record->correlation_id); - // TODO(Keren): Roctracer doesn't support cuda graph yet. + bool hasCorrelation = + correlation.corrIdToExternId.contain(record->correlation_id); auto externId = - correlation.corrIdToExternId.contain(record->correlation_id) + hasCorrelation ? correlation.corrIdToExternId.at(record->correlation_id).first : Scope::DummyScopeId; auto isAPI = correlation.apiExternIds.contain(externId); bool isGraph = pImpl->CorrIdToIsHipGraph.contain(record->correlation_id); - processActivity(correlation.corrIdToExternId, correlation.apiExternIds, - externId, dataSet, record, isAPI, isGraph); - // Track correlation ids from the same stream and erase those < - // correlationId - correlation.corrIdToExternId.erase(record->correlation_id); - correlation.apiExternIds.erase(externId); + if (hasCorrelation) { + processActivity(correlation.corrIdToExternId, correlation.apiExternIds, + externId, dataSet, record, isAPI, isGraph); + // Track correlation ids from the same stream and erase those < + // correlationId + } else { + correlation.apiExternIds.erase(externId); + } roctracer::getNextRecord(record, &record); } correlation.complete(maxCorrelationId); diff --git a/third_party/proton/test/test_instrumentation.py b/third_party/proton/test/test_instrumentation.py index 388a587baf..271b0ff835 100644 --- a/third_party/proton/test/test_instrumentation.py +++ b/third_party/proton/test/test_instrumentation.py @@ -15,7 +15,6 @@ is_cuda, is_hip, is_hip_cdna2, - is_hip_cdna4, supports_tma, supports_ws, ) @@ -644,7 +643,6 @@ def foo(x, y, size: tl.constexpr): assert trace_events[-1]["args"]["call_stack"][-2] == "test" -@pytest.mark.skipif(is_hip_cdna4(), reason="nondeterministic failure") def test_globaltime(tmp_path: pathlib.Path): temp_file = tmp_path / "test_globaltime.chrome_trace" mode = proton.mode.Default( @@ -760,7 +758,6 @@ def session_kernel_time(session_name: str) -> Tuple[int, int]: assert session1_loop_time / session0_loop_time < loop_threshold, "Loop kernel overhead too high" -@pytest.mark.skipif(is_hip(), reason="not implemented yet") def test_gmem_buffer(tmp_path: pathlib.Path): @triton.jit diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index ad8124205d..a5f20a214b 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -75,7 +75,6 @@ def foo(x, y): assert data[0]["children"][1]["frame"]["name"] == "test2" -@pytest.mark.skipif(is_hip(), reason="Currently broken after updating to ROCm 7") def test_cudagraph(tmp_path: pathlib.Path): stream = torch.cuda.Stream() torch.cuda.set_stream(stream) From 89a3c6e5115e11110a8cfb245a0e6ab45e1bea74 Mon Sep 17 00:00:00 2001 From: Pengzhan Zhao Date: Thu, 20 Nov 2025 13:32:58 -0800 Subject: [PATCH 07/17] [AMD] Make kWidth to mandatory for WMMA v3 (#8783) Currently we limit WMMA v3's kWidth to be {2, 8, 16} which matches the hardware view for all possible WMMA instructions. In the case of wmma_scaled, we assume kWidth always to be 16. But in attention kernel, we can use kWidth = 8 which will remove the layout convert between 2 dots. This does not match the hardware view for continuous elements from k dimension, but we can still get correct results unless the kWidth for 2 operands are the same. This PR removes the kWidth check for WMMA v3 and makes it mandatory, same as MFMA. --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 4 +-- .../tritongpu_wmma_dot_scaled_to_llvm.mlir | 26 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index e7fe5b7e2c..bbd5b5d58f 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2501,9 +2501,9 @@ LogicalResult DotOperandEncodingAttr::verify( return emitError() << "ttg.dot_op kWidth parameter must be 4/8/16 for WMMA v2 " "(including packed cases for `scaled_dot`)"; - if (parentAttr.getVersion() == 3 && !llvm::is_contained({2, 8, 16}, kWidth)) + if (parentAttr.getVersion() == 3 && kWidth == 0) return emitError() - << "ttg.dot_op kWidth parameter must be 2/8/16 for WMMA v3"; + << "ttg.dot_op kWidth parameter is mandatory for WMMA v3 "; return success(); } diff --git a/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir index 6337fec57e..9f67f5cb66 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir @@ -200,3 +200,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [32, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [16, 0], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [4, 1], instrShape=[16, 16, 128]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp8_chained + tt.func @wmma_scaled_dot_fp8_chained(%arg0: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg2: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %scale0 = arith.constant dense<127> : tensor<128x4xi8, #linear> + %scale1 = arith.constant dense<127> : tensor<128x4xi8, #linear1> + // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %mm0 = tt.dot_scaled %arg0 scale %scale0, %arg2 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma> + // CHECK-NOT: rocdl.ds_swizzle + // CHECK-NOT: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap" + %op0 = ttg.convert_layout %mm0 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %op1 = tt.fp_to_fp %op0, rounding = rtne : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> -> tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %mm1 = tt.dot_scaled %op1 scale %scale0, %arg3 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<128x128x!tt.ptr, #mma> + tt.store %ptr0, %mm1 : tensor<128x128x!tt.ptr, #mma> + tt.return + } +} From db14c2d417716e1fff094ec58606d217eacde50b Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Thu, 20 Nov 2025 21:49:54 +0000 Subject: [PATCH 08/17] [AMD] Allow async load global to load block dimension duplication (#8788) Broadcasts in the `block` dimensions are not redundant so we should not mask them. This way each CTA has their own copy in shared memory, note that the multicast mask will be set in such cases to efficiently load the data. --- test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir | 3 --- .../amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 9 +++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir b/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir index ec5a5fb418..690ff63cbb 100644 --- a/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir +++ b/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir @@ -81,7 +81,6 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK-LABEL: async_load_multicast_to_half_ctas tt.func public @async_load_multicast_to_half_ctas(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}, %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) { - // CHECK: llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-7 : i32) : i32 // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]] @@ -104,7 +103,6 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha tt.func public @async_load_multicast_group_of_2_strided_by_8(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}, %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) { // Skip the first cluster id because it's emitted for address calculation - // CHECK: llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32 // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]] @@ -146,7 +144,6 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha tt.func public @async_load_multi_cta_linear_layout(%arg0: tensor<32x32x!tt.ptr, #linear> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}, %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) { // Skip the first cluster id because it's emitted for address calculation - // CHECK: llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32 // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]] diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 3d1d50ba9d..2d7c4256f4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1034,8 +1034,13 @@ struct AsyncCopyGlobalToLocalOpConversion zipLoadValues(rewriter, loc, vec, srcElems, srcPtrTy, maskElements, otherElems, otherTy, swizzledLaneOffsets); - Value threadPred = emitRedundantThreadPredicate(getFreeVariableMasks(srcTy), - rewriter, loc, targetInfo); + auto freeVarMasks = getFreeVariableMasks(srcTy); + // We load redundant data on different CTAs so each CTA has a copy in its + // shared memory; the multicast mask will be used by the hardware to + // efficiently broadcast to different CTAs. + freeVarMasks[rewriter.getStringAttr("block")] = 0; + Value threadPred = + emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); auto emitGlobalLoadLds = From 31281bcdd769d79519887cd6ba31570e1e2f6ee2 Mon Sep 17 00:00:00 2001 From: neildhar Date: Thu, 20 Nov 2025 15:10:02 -0800 Subject: [PATCH 09/17] [Reland] Fix handling of unvisited operands in AxisInfoAnalysis (#8758) We currently force initialisation of operands that have not yet been visited with `setToEntryState`. This means that the order in which values are visited can change the results of the analysis. This can be a source of bugs. For example, the lowering for `AsyncCopyGlobalToLocalOp` validates that the load addresses permit sufficient vectorisation, however, this is up to the analysis actually recovering the same information it had when the async copy was created. Otherwise, we crash during lowering. I have an actual repro for this but it has been very difficult to minimise it enough to make it suitable for an lit test: https://gist.github.com/neildhar/7eea6a312afa39d1cc83dc12627c2ba3 Populating the operands in this way also means that we have to handle control flow like `ForOp` and `IfOp` explicitly in `setToEntryState`, because we may attempt to populate their results when we visit their users. Instead, when we encounter an operation whose operands have not yet been encountered, skip over the operation entirely. We can revisit it once the operands have actually been visited. This improves the quality of the analysis, and leaves the handling of control flow to the dataflow framework. This reland adds handling for the case where the dataflow analysis fails to initialise a particular value (likely because it is determined to be dead). --- lib/Analysis/AxisInfo.cpp | 34 ++++++++++++++----------------- test/Analysis/test-alignment.mlir | 15 ++++++++++++++ 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 159392f174..50ded51aa3 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1075,11 +1075,10 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver, LogicalResult AxisInfoAnalysis::visitOperation( Operation *op, ArrayRef *> operands, ArrayRef *> results) { - // TODO: For sure not the right way to do this - // but why is scf.if not initialized otherwise? + // If any operands are not yet ready, skip this operation for now. for (auto op : operands) if (op->getValue().getRank() == 0) - setToEntryState((dataflow::Lattice *)op); + return success(); AxisInfo curr = visitors.apply(op, operands); if (curr.getRank() == 0) { setAllToEntryStates(results); @@ -1108,9 +1107,11 @@ void AxisInfoAnalysis::visitForOpInductionVar( ProgramPoint *programPoint = getProgramPointAfter(op); auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound()); auto *stepLattice = getLatticeElementFor(programPoint, op.getStep()); - for (auto op_iter : {lbLattice, stepLattice}) - if (op_iter->getValue().getRank() == 0) - setToEntryState((dataflow::Lattice *)op_iter); + // If lb or step is not yet ready, skip this operation for now. + if (lbLattice->getValue().getRank() == 0 || + stepLattice->getValue().getRank() == 0) { + return; + } AxisInfo::DimVectorT knownContiguity(1, 1); AxisInfo::DimVectorT knownDivisibility(1, 1); @@ -1184,24 +1185,15 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, &knownContiguity, &knownDivisibility, &knownConstancy); - } else if (isa( - op)) { - // scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp - // Control flow operations are initialized with "unknown" state: - // the maximum possible divisibility, contiguity, and constancy. + } else if (isa(op)) { + // Initialize the arguments to gpu::WarpSpecializePartitionsOp with + // "unknown" state: the maximum possible divisibility, contiguity, and + // constancy. knownDivisibility = DimVectorT(rank, kMaxDivisor); knownConstancy = DimVectorT(rank, kMaxDivisor); knownContiguity = DimVectorT(rank, kMaxDivisor); } } else if (Operation *op = value.getDefiningOp()) { - if (isa(op)) { - // scf::ForOp, scf::IfOp, scf::WhileOp - // Control flow operations are initialized with "unknown" state: - // the maximum possible divisibility, contiguity, and constancy. - knownDivisibility = DimVectorT(rank, kMaxDivisor); - knownConstancy = DimVectorT(rank, kMaxDivisor); - knownContiguity = DimVectorT(rank, kMaxDivisor); - } // Other operations are conservatively initialized with the lowest possible // divisibility, contiguity, and constancy unless they have specified. AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"), @@ -1354,6 +1346,10 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp, auto *axisInfoMap = getFuncData(funcOp); auto updateAxisInfoMap = [&](Value value) { auto axisInfo = analysis->getLatticeElement(value)->getValue(); + // If we could not determine the AxisInfo for this value, assume the + // pessimistic state. + if (axisInfo.getRank() == 0) + axisInfo = AxisInfo::getPessimisticValueState(value); AxisInfo curAxisInfo; if (axisInfoMap->count(value)) { curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value)); diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 49821e969f..7f17dca4f9 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -1089,3 +1089,18 @@ tt.func public @test_inductor_for() { } tt.return } + +// ----- + +// Verify that if an operation is statically determined to be dead, we fall back +// to assigning it a pessimistic value, rather than skipping it entirely. +tt.func @dead_op_pessimistic() { + %c5 = arith.constant dense<5> : tensor<4xi32> + %c7 = arith.constant dense<7> : tensor<4xi32> + %false = arith.constant false + scf.if %false { + // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %add = arith.addi %c5, %c7 : tensor<4xi32> + } + tt.return +} From 6294db5a12443a49d1f0604a8de08d2b4b921497 Mon Sep 17 00:00:00 2001 From: aeng-openai Date: Thu, 20 Nov 2025 18:04:48 -0800 Subject: [PATCH 10/17] [KERNELS] fix persistent matmul heuristics (#8791) any mxfp where natively supported requires using the persistent matmul kernel. in these cases, do not use heuristics to resolve `is_persistent` --- .../triton_kernels/matmul_ogs_details/opt_flags.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 53a79c3f5b..8aa185b0ae 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -3,11 +3,12 @@ from dataclasses import dataclass import triton +from triton_kernels import target_info from triton_kernels.target_info import get_cdna_version from triton_kernels.tensor import FP4 import torch from .opt_flags_details import opt_flags_amd, opt_flags_nvidia -from triton_kernels.tensor import bitwidth +from triton_kernels.tensor import bitwidth, get_layout @dataclass @@ -215,8 +216,12 @@ def make_default_opt_flags_nvidia( n_sms = torch.cuda.get_device_properties(0).multi_processor_count tiles_per_sm = grid_size_tma / n_sms supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9) + requires_persistent = (get_layout(precision_config.act_scale) is not None or get_layout(precision_config.weight_scale) is not None) and target_info.has_native_mxfp() if constraints.get("is_persistent", None) is not None: is_persistent = constraints["is_persistent"] + elif requires_persistent: + assert supports_persistent, "persistent kernel required but not supported" + is_persistent = True else: has_simple_epilogue = precision_config.max_num_imprecise_acc is None is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4 From 4823a6e711d1ed1d85dc58dc1aaba13efe735497 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Fri, 21 Nov 2025 16:40:07 +0000 Subject: [PATCH 11/17] [BACKEND] Support clamp optimization on scalars (#8796) While poking around in this code, I noticed this optimization only supports tensors. This PR generalizes it to work on scalars as well. --- test/Conversion/tritongpu_to_llvm_hopper.mlir | 14 +++++ .../ElementwiseOpToLLVM.cpp | 52 +++++++++++-------- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index b8cec1dede..017e08b39d 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -229,6 +229,20 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- +// CHECK-LABEL: clamp_scalar +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func public @clamp_scalar(%x : f32, %limit : f32) { + %cst = arith.constant 0.000000e+00 : f32 + %neg_limit = arith.subf %cst, %limit : f32 + + // CHECK: nvvm.fmin.xorsign.abs.f + %12 = tt.clampf %x, %neg_limit, %limit, propagateNan = none : f32 + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> // CHECK-LABEL: convert_mma_to_blocked diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index 2204e8caa0..3cea4e1844 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -671,7 +671,11 @@ struct ClampFOpConversion computeCapability(computeCapability) {} bool isClipPattern(ClampFOp op) const { - bool xorsignAbsAvailable = (computeCapability >= 90); + // min.xorsign.abs requires hopper or newer + if (computeCapability < 90) { + return false; + } + // Pattern matching the sequence of clamp(x, -limit, limit) to generate // more efficient PTX code. NOTE: This pattern matching is not general // enough, but it is sufficient. We detect only two cases here: @@ -684,38 +688,40 @@ struct ClampFOpConversion // %cst_6 = arith.constant dense<-6.0000e+00> // %cst_7 = arith.constant dense<6.0000e+00> // %160 = tt.clamp %158, %cst_6, %cst_7 - bool patternFound = false; auto getSplatInitializer = [](Value v) -> std::optional { - if (auto constOp = v.getDefiningOp()) { - if (auto attr = mlir::dyn_cast( - constOp.getValueAttr())) { - if (attr.isSplat()) { - return attr.getSplatValue().convertToDouble(); - } + DenseIntOrFPElementsAttr denseAttr; + if (matchPattern(v, m_Constant(&denseAttr))) { + if (denseAttr.isSplat()) { + return denseAttr.getSplatValue().convertToDouble(); } + return std::nullopt; + } + FloatAttr floatAttr; + if (matchPattern(v, m_Constant(&floatAttr))) { + return floatAttr.getValue().convertToDouble(); } return std::nullopt; }; - if (xorsignAbsAvailable) { - if (auto subOp = op.getOperand(1).getDefiningOp()) { - if (subOp.getOperand(1) == op.getOperand(2)) { - auto initializer = getSplatInitializer(subOp.getOperand(0)); - if (initializer.has_value() && initializer.value() == 0.0) { - patternFound = true; - } - } - } else { - auto initializer1 = getSplatInitializer(op.getOperand(1)); - auto initializer2 = getSplatInitializer(op.getOperand(2)); - if (initializer1.has_value() && initializer2.has_value() && - initializer1.value() == -initializer2.value()) { - patternFound = true; + // clampf %x (sub 0.0 %max) %max + if (auto subOp = op.getOperand(1).getDefiningOp()) { + if (subOp.getOperand(1) == op.getOperand(2)) { + auto initializer = getSplatInitializer(subOp.getOperand(0)); + if (initializer.has_value() && initializer.value() == 0.0) { + return true; } } } - return patternFound; + + // clampf %x, %min, %max (where min = -max = constant) + auto initializer1 = getSplatInitializer(op.getOperand(1)); + auto initializer2 = getSplatInitializer(op.getOperand(2)); + if (initializer1.has_value() && initializer2.has_value() && + initializer1.value() == -initializer2.value()) { + return true; + } + return false; } SmallVector emitOptimization(ClampFOp op, From 96bba6b92f37a2b583d8f0003aa017a64f9d854f Mon Sep 17 00:00:00 2001 From: xiaohuguo2023 <149615094+xiaohuguo2023@users.noreply.github.com> Date: Fri, 21 Nov 2025 16:47:23 +0000 Subject: [PATCH 12/17] [AMD] Replace usage of llvm copysign intrinsic (#8789) Use FCmp + Select + FMul instead of llvm.copysign.f32. This avoids some perf regressions. --- .../TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp index 75f04163d6..9437cc0e0e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -142,13 +142,19 @@ class CallOpConversion : public OpRewritePattern { LLVM::FSubOp::create(rewriter, loc, rewriter.getF32Type(), one, ratio->getResult(0), defaultFlags); - // Apply the sign of the original input using copysign + // Apply the sign of the original input without using copysign intrinsic // tanh(x) = sign(x) * (1 - 2/(e^(2*|x|) + 1)) - const char *intrinsic = "llvm.copysign.f32"; - auto args = - llvm::SmallVector{posResult->getResult(0), operands[0]}; - replacementOp = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, - returnType, args); + // Use FCmp + Select + FMul instead of copysign to avoid potential LLVM + // optimization side effects that may affect other operations + auto zero = LLVM::createConstantF32(loc, rewriter, 0.0); + auto negOne = LLVM::createConstantF32(loc, rewriter, -1.0); + auto isNegative = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::olt, operands[0], zero); + auto sign = LLVM::SelectOp::create(rewriter, loc, rewriter.getF32Type(), + isNegative, negOne, one); + replacementOp = LLVM::FMulOp::create(rewriter, loc, returnType, + posResult->getResult(0), + sign->getResult(0), defaultFlags); } if (replacementOp) { From 4cf9906efe3dcd77eac1561c8b7e55140cfb7da1 Mon Sep 17 00:00:00 2001 From: Ravil Dorozhinskii Date: Fri, 21 Nov 2025 20:40:40 +0100 Subject: [PATCH 13/17] [AMD] Extended membar analysis with third_party ops using a trait (#8798) This PR adds `MemWaitOpTrait` trait which is used to identify all wait instructions operating on the memory. This allows to treat wait-operations from the third party dialects in the membar analysis in the same way as the native ones. This removes a workaround from AMDGPU backend. --- include/triton/Dialect/TritonGPU/IR/Traits.h | 6 ++++++ .../Dialect/TritonGPU/IR/TritonGPUAttrBase.td | 1 + .../triton/Dialect/TritonGPU/IR/TritonGPUOps.td | 2 +- .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 2 +- lib/Analysis/Membar.cpp | 2 +- .../Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td | 5 ++--- .../amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp | 16 ---------------- .../amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h | 7 ------- .../lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 1 - 9 files changed, 12 insertions(+), 30 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/Traits.h b/include/triton/Dialect/TritonGPU/IR/Traits.h index 9867c287f1..03ad522236 100644 --- a/include/triton/Dialect/TritonGPU/IR/Traits.h +++ b/include/triton/Dialect/TritonGPU/IR/Traits.h @@ -22,6 +22,12 @@ class LocalLoadTrait // Optional: Add methods or verification logic here }; +template +class MemWaitOpTrait + : public mlir::OpTrait::TraitBase { + // Optional: Add methods or verification logic here +}; + } // namespace OpTrait } // namespace mlir diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td index fa0d582b7b..4b37e0c8ba 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td @@ -14,6 +14,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" // Traits used across several attrs. def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">; def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">; +def MemWaitOpTrait : NativeOpTrait<"MemWaitOpTrait">; // Common parameter helpers. def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout", diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 04943280fc..e7ed918eb0 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -42,7 +42,7 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; } -def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { +def TTG_AsyncWaitOp : TTG_Op<"async_wait", [MemWaitOpTrait]> { let summary = "async wait"; let arguments = (ins Variadic:$asyncToken, I32Attr:$num); diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index b43d5b4b3f..fde58b6bfc 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -401,7 +401,7 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter"> { let hasVerifier = 1; } -def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> { +def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait", [MemWaitOpTrait]> { let summary = "wait until all the inputs are read."; let arguments = (ins I32Attr:$pendings); let description = [{ diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 400efe58a9..d06f3c2e99 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -171,7 +171,7 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, return; } - if (isa(op) && + if (op->hasTrait() && !isa(op->getNextNode())) { // If the current op is an async wait and the next op is not a barrier we // insert a barrier op and sync diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 2250f220c3..181190c413 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -784,7 +784,7 @@ def AsyncTDMCopyLocalToGlobalOp : TT_AMDGPU_Op<"async_tdm_copy_local_to_global"> // AsyncTDMWait //===----------------------------------------------------------------------===// -def AsyncTDMWait : TT_AMDGPU_Op<"async_tdm_wait"> { +def AsyncTDMWait : TT_AMDGPU_Op<"async_tdm_wait", [MemWaitOpTrait]> { let summary = "Wait until there are less than or equal to the given number of outstanding TDM operations"; let arguments = (ins Variadic:$asyncToken, I32Attr:$num); let description = [{ @@ -793,7 +793,6 @@ def AsyncTDMWait : TT_AMDGPU_Op<"async_tdm_wait"> { necessary to ensure that data is available in the LDS before it is used. }]; let results = (outs TTG_AsyncToken:$retToken); - let assemblyFormat = "$asyncToken attr-dict"; } @@ -801,7 +800,7 @@ def AsyncTDMWait : TT_AMDGPU_Op<"async_tdm_wait"> { // AsyncWait //===----------------------------------------------------------------------===// -def AsyncWaitOp : TT_AMDGPU_Op<"async_wait"> { +def AsyncWaitOp : TT_AMDGPU_Op<"async_wait", [MemWaitOpTrait]> { let summary = "Wait until there are less than or equal to the given number of outstanding async intrinsics"; let description = [{ Similar to ttg.async_wait but instead of waiting on oustanding ttg.async_commit_groups diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp index f5cc7fd982..d73606984a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp @@ -51,22 +51,6 @@ bool comesFromAsyncWait(Value token) { } } // namespace -void addLocalBarrierAfterAmdGpuAsyncWait(ModuleOp mod) { - auto *ctx = mod->getContext(); - - SmallVector waits; - mod->walk([&waits](amdgpu::AsyncWaitOp waitOp) { waits.push_back(waitOp); }); - - IRRewriter builder(mod.getContext()); - for (auto waitOp : waits) { - if (isa(waitOp->getNextNode())) - continue; - - builder.setInsertionPointAfter(waitOp); - builder.create(waitOp->getLoc()); - } -} - void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod) { auto *ctx = mod->getContext(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h index d5206a43f7..174850e415 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h @@ -8,13 +8,6 @@ namespace mlir::triton::AMD { class TargetInfo; - -// Walks the module and adds a LocalBarrier after any amdg.async_wait if there -// is not already a barrier following it. This mimicks what Member does for -// common async wait operations and avoids AMD specific modifications to Membar. -// This yields to the same behaviour compared to when membar adds the barrier. -void addLocalBarrierAfterAmdGpuAsyncWait(ModuleOp mod); - // Annotates LocalLoadOps with ttg.amdg.syncedByAsyncWait=true if they are // synced by an AsyncWait. void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index c4bb2dab1c..dfd2157e5a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -130,7 +130,6 @@ struct ConvertTritonAMDGPUToLLVM if (targetInfo.requiresAliasInfoForAsyncOps()) AMD::annotateLocalLoadsSyncedViaAsyncWait(mod); - AMD::addLocalBarrierAfterAmdGpuAsyncWait(mod); ModuleMembarAnalysis membarPass(&allocation, mlir::triton::AMD::membarFilter); membarPass.run(); From 046ab0e218d83388855368242b421dfd802a1141 Mon Sep 17 00:00:00 2001 From: neildhar Date: Fri, 21 Nov 2025 12:17:54 -0800 Subject: [PATCH 14/17] [NFC] Simplify populating axisinfo map (#8800) A default constructed `AxisInfo` passed as an operand to `AxisInfo::join` will always result in `join` returning the other operand. This means that we can call `join` unconditionally even when there is no existing entry in the map. This collapses the three separate map lookups (the check, the join, and the population) to just a single one. --- lib/Analysis/AxisInfo.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 50ded51aa3..ead683cf42 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1350,13 +1350,8 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp, // pessimistic state. if (axisInfo.getRank() == 0) axisInfo = AxisInfo::getPessimisticValueState(value); - AxisInfo curAxisInfo; - if (axisInfoMap->count(value)) { - curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value)); - } else { - curAxisInfo = axisInfo; - } - (*axisInfoMap)[value] = curAxisInfo; + auto &valInfo = (*axisInfoMap)[value]; + valInfo = AxisInfo::join(axisInfo, valInfo); }; funcOp.walk([&](Operation *op) { for (auto value : op->getResults()) { From 4b184cccbb42c3aa5f24e34c39c5358636b3e2bd Mon Sep 17 00:00:00 2001 From: Evghenii <1485713+3gx@users.noreply.github.com> Date: Fri, 21 Nov 2025 14:17:35 -0800 Subject: [PATCH 15/17] patch workaround by correctly setting stage/cluster attrubtes (#8797) * patches workaround for loop-scheduler by using stage/cluster from previous tmem access op in the partition to set stage/cluster for put.exit op, and if needed for the follow-up put.enter op --- test/NVWS/aref-tmem-insertion.mlir | 54 +++++++++++++++++++ .../NVWS/Transforms/InsertTmemAref.cpp | 38 ++++++++----- 2 files changed, 78 insertions(+), 14 deletions(-) diff --git a/test/NVWS/aref-tmem-insertion.mlir b/test/NVWS/aref-tmem-insertion.mlir index 00c964ec20..a8692974c3 100644 --- a/test/NVWS/aref-tmem-insertion.mlir +++ b/test/NVWS/aref-tmem-insertion.mlir @@ -788,3 +788,57 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory +#tmem = #ttng.tensor_memory_encoding +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { + // CHECK-LABEL: @if_split_workaround + tt.func @if_split_workaround(%arg0: !tt.tensordesc>, %arg1: tensor<64x128x!tt.ptr, #blocked3> {tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.divisibility = dense<16> : tensor<2xi32>}) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c32_i32 = arith.constant 32 : i32 + %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) + %0 = ttng.tmem_store %cst, %result[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + // CHECK: scf.for + %1:3 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %true, %arg4 = %arg1, %arg5 = %0) -> (i1, tensor<64x128x!tt.ptr, #blocked3>, !ttg.async.token) : i32 { + %2:3 = "get_offsets"(%arg2) {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array} : (i32) -> (i32, tensor<64x128xi32, #blocked3>, i32) + %3 = tt.splat %2#0 {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array} : i32 -> tensor<128xi32, #blocked2> + %4 = tt.descriptor_gather %arg0[%3, %2#2] {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array} : (!tt.tensordesc>, tensor<128xi32, #blocked2>, i32) -> tensor<128x64xf16, #blocked1> + %5 = tt.addptr %arg4, %2#1 {loop.cluster = 3 : i32, loop.stage = 1 : i32, tt.constancy = dense<1> : tensor<2xi32>, tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.divisibility = dense<16> : tensor<2xi32>, ttg.partition = array} : tensor<64x128x!tt.ptr, #blocked3>, tensor<64x128xi32, #blocked3> + %6 = tt.load %5 {loop.cluster = 3 : i32, loop.stage = 1 : i32, ttg.partition = array} : tensor<64x128x!tt.ptr, #blocked3> + %7 = ttg.local_alloc %4 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %8 = ttg.local_alloc %6 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array} : (tensor<64x128xf16, #blocked3>) -> !ttg.memdesc<64x128xf16, #shared, #smem> + // CHECK: tc_gen5_mma {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32 + %9 = ttng.tc_gen5_mma %7, %8, %result[%arg5], %arg3, %true {loop.cluster = 2 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %10 = arith.cmpi eq, %arg2, %c0_i32 {loop.cluster = 1 : i32, loop.stage = 3 : i32, ttg.partition = array} : i32 + %11 = arith.select %10, %false, %true {loop.cluster = 1 : i32, loop.stage = 3 : i32, ttg.partition = array} : i1 + // CHECK: scf.if + // CHECK-NEXT: put.exit {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32 + // CHECK} {loop.cluster = 2 : i32, loop.stage = 2 : i32 + // CHECK: scf.if + // CHECK: } {loop.cluster = 4 : i32, loop.stage = 3 : i32 + // CHECK: scf.if + // CKECK-NEXT: put.enter {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32 + // CHECK: } {loop.cluster = 2 : i32, loop.stage = 2 : i32 + %12 = scf.if %10 -> (!ttg.async.token) { + %result_0, %token_1 = ttng.tmem_load %result[%9] {ttg.partition = array} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> + "acc_user"(%result_0) {ttg.partition = array} : (tensor<128x128xf32, #blocked>) -> () + scf.yield {ttg.partition = array} %token_1 : !ttg.async.token + } else { + scf.yield {ttg.partition = array} %9 : !ttg.async.token + } {loop.cluster = 4 : i32, loop.stage = 3 : i32, ttg.partition = array, ttg.partition.outputs = [array]} + scf.yield {ttg.partition = array} %11, %5, %12 : i1, tensor<64x128x!tt.ptr, #blocked3>, !ttg.async.token + } {tt.disallow_acc_multi_buffer, tt.num_stages = 3 : i32, tt.scheduled_max_stage = 3 : i32, tt.warp_specialize, ttg.partition = array, ttg.partition.outputs = [array, array, array], ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 2 : i32} + tt.return + } +} diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp index ce4c37075f..affdd768e5 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp @@ -402,10 +402,15 @@ struct TMEMAref { token = op.getToken(); } partitionId = paritionIdStageCluster.first; + if (partitionId) + stageClusters[*partitionId] = paritionIdStageCluster.second; buffer = {}; } - void release(OpBuilder &b, Location loc, StageCluster stageCluster) { + void release(OpBuilder &b, Location loc) { assert(asyncOp); + StageCluster stageCluster; + if (partitionId) + stageCluster = stageClusters[*partitionId]; if (kind == PUT) { createInto( b, loc, {partitionId, stageCluster}, aref, token, @@ -447,6 +452,7 @@ struct TMEMAref { Kind kind; std::optional partitionId; std::optional asyncOp; + DenseMap stageClusters; }; TmemAccessDag::Node * @@ -458,12 +464,10 @@ insertTmemArefImpl(TmemAccessDag::Node *node, if (curPartitionId && node->partitionId != curPartitionId) { OpBuilder b(node->op); Operation *prevOp = nullptr; - StageCluster prevStageCluster; if (node->parent) { // release right after the last op which owns the tmem prevOp = node->parent->op; b.setInsertionPointAfter(prevOp); - prevStageCluster = getStageCluster(prevOp); } else { // if we are inside if-stmt or for-stmt subdag and need to change // ownerhip, release at the top of the block @@ -471,12 +475,7 @@ insertTmemArefImpl(TmemAccessDag::Node *node, prevOp = node->parentDag->op; b.setInsertionPointToStart(node->op->getBlock()); } - if (!node->partitionId) { - // if node->partitionId is not set, it means we are outside ws-region - // reset prevPartitionId and prevStageCluster to defaults - prevStageCluster = {}; - } - state.release(b, prevOp->getLoc(), prevStageCluster); + state.release(b, prevOp->getLoc()); // acquire right before op that acquires ownership of tmem auto curOp = node->op; @@ -489,6 +488,10 @@ insertTmemArefImpl(TmemAccessDag::Node *node, curOp = node->parentDag->op; } auto stageCluster = getStageCluster(curOp); + // if stage-cluster is empty, use the stage-cluster used from the last op + // that acquired ownership of tmem in a partition + if (!stageCluster && partitionId) + stageCluster = state.stageClusters[*partitionId]; state.acquire(b, curOp->getLoc(), {partitionId, stageCluster}); } @@ -519,16 +522,22 @@ insertTmemArefImpl(TmemAccessDag::Node *node, OpBuilder b(node->op); if (auto tmemLoadOp = dyn_cast(node->op)) { + if (auto id = node->partitionId) + state.stageClusters[*id] = getStageCluster(node->op); tmemLoadOp.getSrcMutable().assign( state.getBuffer(b, node->partitionId, node->op)); tmemLoadOp.getDepMutable().clear(); tmemLoadOp.getToken().replaceAllUsesWith(state.replToken); } else if (auto tmemStoreOp = dyn_cast(node->op)) { + if (auto id = node->partitionId) + state.stageClusters[*id] = getStageCluster(node->op); tmemStoreOp.getDstMutable().assign( state.getBuffer(b, node->partitionId, node->op)); tmemStoreOp.getDepMutable().clear(); tmemStoreOp.getToken().replaceAllUsesWith(state.replToken); } else if (auto mmaOp = dyn_cast(node->op)) { + if (auto id = node->partitionId) + state.stageClusters[*id] = getStageCluster(node->op); if (mmaOp.getAccumulator() == state.origBuffer) { mmaOp.getAccDepMutable().clear(); mmaOp.getToken().replaceAllUsesWith(state.replToken); @@ -640,10 +649,11 @@ LogicalResult insertTmemAref(TmemAccessDag &accessDag) { // aref is used outside ws-loop, find the last point in the same block as // create op to have matching exit auto op1 = arefOp->getBlock()->findAncestorOpInBlock(*node->op); + if (auto id = node->partitionId) + state.stageClusters[*id] = {}; b.setInsertionPointAfter(op1); } - stageCluster = getStageCluster(node->op); - state.release(b, node->op->getLoc(), stageCluster); + state.release(b, node->op->getLoc()); if (state.kind == TMEMAref::GET) { // When the state ends up in a GET operation, we need to acquire and release @@ -661,7 +671,7 @@ LogicalResult insertTmemAref(TmemAccessDag &accessDag) { } } state.acquire(b, node->op->getLoc(), {otherPartitionId, {}}); - state.release(b, node->op->getLoc(), {}); + state.release(b, node->op->getLoc()); } return success(); @@ -751,8 +761,8 @@ void workaroundForLoopScheduler(triton::FuncOp funcOp) { // patch loop.stage=1 enterIf->setAttrs(ifOp->getAttrs()); exitIf->setAttrs(ifOp->getAttrs()); - enterIf->setAttr(kLoopStageAttrName, b.getI32IntegerAttr(1)); - exitIf->setAttr(kLoopStageAttrName, b.getI32IntegerAttr(1)); + assignStage(b, enterIf, getStageCluster(putEnterOp)); + assignStage(b, exitIf, getStageCluster(putExitOp)); SetVector enterExitIds, middleIds; enterExitIds.insert(1); From 29009f1b136b738d354ffcb4e89c4bd3f2343832 Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Sat, 22 Nov 2025 00:00:47 +0000 Subject: [PATCH 16/17] [GLUON] Allow TensorMemory layouts in `to_linear_layout` in the context of printing. (#8682) --- python/src/gluon_ir.cc | 43 ++++++++++++++- python/test/gluon/test_frontend.py | 53 ++++++------------- .../experimental/gluon/language/_layouts.py | 10 ++++ .../experimental/gluon/language/_semantic.py | 19 +++---- .../language/nvidia/blackwell/__init__.py | 22 ++++++++ 5 files changed, 100 insertions(+), 47 deletions(-) diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index c4009bb424..4ea5d4130b 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -375,8 +375,47 @@ void init_gluon_ir(py::module &&m) { std::vector &shape) -> py::object { auto ctx = self.getContext(); auto linearLayout = ttg::toLinearLayout(shape, layout); - auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout); - return layoutToGluon(attr); + + if (isa(layout)) { + auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout); + return layoutToGluon(attr); + } + if (isa(layout)) { + auto alignment = + cast(layout).getAlignment(); + auto attr = ttg::SharedLinearEncodingAttr::get(ctx, linearLayout, + alignment); + return layoutToGluon(attr); + } + + // TensorMemory encodings: keep the LinearLayout but wrap as + // print-only Python object carrying row/col bases -> dim0/dim1. + auto inNamesRange = linearLayout.getInDimNames(); + auto inNames = llvm::to_vector(inNamesRange); + bool isTmemLayout = + (inNames.size() == 2 && inNames[0].str() == "row" && + inNames[1].str() == "col"); + if (!isTmemLayout) + throw std::invalid_argument( + "Unsupported layout in to_linear_layout"); + + // Build Py _TensorMemoryLinearLayout(row_bases, col_bases, shape, + // repr) + py::object tmemCls = + py::module::import( + "triton.experimental.gluon.language.nvidia.blackwell") + .attr("_TensorMemoryLinearLayout"); + auto bases = linearLayout.getBases(); + auto rowBases = bases[mlir::StringAttr::get(ctx, "row")]; + auto colBases = bases[mlir::StringAttr::get(ctx, "col")]; + auto outDims = linearLayout.getOutDims(); + std::vector shapeVec; + for (auto &od : outDims) + shapeVec.push_back(od.second); + + py::object pyObj = tmemCls(py::cast(rowBases), py::cast(colBases), + py::cast(shapeVec)); + return pyObj; }) .def("get_dot_operand_layout", [](GluonOpBuilder &self, unsigned opIdx, Attribute parent, diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 204cb10e0c..aa4b1b2fed 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -1461,48 +1461,29 @@ def kernel(reg_type: ttgl.constexpr, shared_type: ttgl.constexpr, ref_conflicts: @pytest.mark.parametrize( - "layout, expected", + "layout, shape", [ - ( - ttgl.BlockedLayout([1], [4], [4], [0]), - ttgl.DistributedLinearLayout( - reg_bases=[], - lane_bases=[[1], [2]], - warp_bases=[[4], [8]], - block_bases=[], - shape=[16], - ), - ), - ( - ttgl.BlockedLayout([1], [4], [4], [0], [[1], [0]]), - ttgl.DistributedLinearLayout( - reg_bases=[], - lane_bases=[[1], [2]], - warp_bases=[[4], [8]], - block_bases=[[16], [0]], - shape=[32], - ), - ), - ( - ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1], [[0, 1]]), - ttgl.DistributedLinearLayout( - reg_bases=[[1, 0], [2, 0], [4, 0], [0, 16], [0, 32]], - lane_bases=[[8, 0], [16, 0], [32, 0], [0, 1], [0, 2]], - warp_bases=[[0, 4], [0, 8]], - block_bases=[[0, 64]], - shape=[64, 128], - ), - ), + (ttgl.BlockedLayout([1], [4], [4], [0]), [16]), + (ttgl.BlockedLayout([1], [4], [4], [0], [[1], [0]]), [32]), + (ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1], [[0, 1]]), [64, 128]), + (ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2), [64, 64]), + (TensorMemoryLayout((64, 64), col_stride=2), [64, 64]), ], ) -def test_to_linear_layout(layout, expected): +def test_to_linear_layout(layout, shape, capsys): @gluon.jit - def kernel(layout: ttgl.constexpr, expected: ttgl.constexpr, shape: ttgl.constexpr): + def kernel(layout: ttgl.constexpr, shape: ttgl.constexpr): computed: ttgl.constexpr = ttgl.to_linear_layout(layout, shape) - ttgl.static_assert(computed == expected) - - run_parser(kernel, args=(layout, expected, tuple(expected.shape)), target=AMPERE_TARGET) + ttgl.static_print(computed) + + run_parser(kernel, args=(layout, tuple(shape)), target=AMPERE_TARGET) + out = capsys.readouterr().out + if isinstance(layout, TensorMemoryLayout): + assert "rows=" in out + assert "cols=" in out + else: + assert "DistributedLinearLayout" in out or "SharedLinearLayout" in out @filecheck_test diff --git a/python/triton/experimental/gluon/language/_layouts.py b/python/triton/experimental/gluon/language/_layouts.py index 7f5a2c4002..b876b2f6b4 100644 --- a/python/triton/experimental/gluon/language/_layouts.py +++ b/python/triton/experimental/gluon/language/_layouts.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +import itertools from typing import List from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type @@ -636,6 +637,15 @@ def _to_ir(self, builder): def mangle(self) -> str: return f"SharedLinear_{self.offset_bases}_{self.block_bases}_{self.alignment}_SharedLinear" + @property + def shape(self): + rank = len(self.offset_bases[0]) + max_stride = [1] * rank + for b in itertools.chain(self.offset_bases, self.block_bases): + for i, bi in enumerate(b): + max_stride[i] = max(max_stride[i], bi) + return [2 * s for s in max_stride] + def __hash__(self): return hash(( tuple(map(tuple, self.offset_bases)), diff --git a/python/triton/experimental/gluon/language/_semantic.py b/python/triton/experimental/gluon/language/_semantic.py index ec019cbe4a..8843005155 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -2,7 +2,7 @@ import math from triton.language.semantic import TritonSemantic from . import _core as ttgl -from ._layouts import AutoLayout, DistributedLayout, DistributedLinearLayout, SliceLayout, SharedLayout, CoalescedLayout +from ._layouts import AutoLayout, DistributedLayout, DistributedLinearLayout, SliceLayout, SharedLayout, CoalescedLayout, SharedLinearLayout from triton._C.libtriton.gluon_ir import GluonOpBuilder, compute_tmem_reg_layout from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values @@ -301,15 +301,16 @@ def bank_conflicts(self, distr_ty, shared_ty): distr_ty.element_ty.primitive_bitwidth) def to_linear_layout(self, layout, shape): - _check(isinstance(layout, (DistributedLayout, SharedLayout)), - lambda: f"Expected a DistributedLayout or SharedLayout, got {type(layout)}") - - if not isinstance(shape, list): - shape = list(shape) - - layout = ttgl._unwrap_if_constexpr(layout) + from triton.experimental.gluon.language.nvidia.blackwell import ( + TensorMemoryLayout, + TensorMemoryScalesLayout, + ) + _check( + isinstance(layout, (DistributedLayout, SharedLayout, TensorMemoryLayout, TensorMemoryScalesLayout)), lambda: + f"Expected a DistributedLayout, SharedLayout, or TensorMemoryLayout or TensorMemoryScalesLayout, got {type(layout)}" + ) - if isinstance(layout, (AutoLayout, DistributedLinearLayout)): + if isinstance(layout, (AutoLayout, DistributedLinearLayout, SharedLinearLayout)): return ttgl.constexpr(layout) return ttgl.constexpr(self.builder.to_linear_layout(layout._to_ir(self.builder), shape)) diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index 6d1b21c011..16650e8743 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -2,6 +2,7 @@ from typing import Optional, Tuple, List, TYPE_CHECKING from dataclasses import dataclass +import itertools from triton.runtime.jit import constexpr_function from triton.experimental.gluon.language import _core as ttgl from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr @@ -26,7 +27,9 @@ "mma_v2", "tensor_memory_descriptor", "TensorMemoryLayout", + "TensorMemoryScalesLayout", "tma", + "_TensorMemoryLinearLayout", ] @@ -104,6 +107,25 @@ def __hash__(self): return hash(self.cta_split_num) +@dataclass(frozen=True) +class _TensorMemoryLinearLayout: + """ + Print-only linear layout for TMEM (row/col -> dim0/dim1). + """ + rows: List[List[int]] + cols: List[List[int]] + shape: List[int] + + def _to_ir(self, builder): + raise RuntimeError("TensorMemoryLinearLayout is print-only; IR materialization is unsupported") + + def mangle(self): + return f"TMLL_{self.shape}_TMLL" + + def __hash__(self): + return hash((tuple(map(tuple, self.rows)), tuple(map(tuple, self.cols)), tuple(self.shape))) + + @constexpr_function def get_tmem_reg_layout( element_ty, From 546a718ab0154b05674cd2ab3c4378af7658e698 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 1 Dec 2025 11:16:01 +0000 Subject: [PATCH 17/17] Revert "[Reland] Fix handling of unvisited operands in AxisInfoAnalysis (#8758)" This reverts commit 31281bcdd769d79519887cd6ba31570e1e2f6ee2. --- lib/Analysis/AxisInfo.cpp | 34 +++++++++++++++++-------------- test/Analysis/test-alignment.mlir | 15 -------------- 2 files changed, 19 insertions(+), 30 deletions(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 6aa84e6b7e..22889e46f9 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1079,10 +1079,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver, LogicalResult AxisInfoAnalysis::visitOperation( Operation *op, ArrayRef *> operands, ArrayRef *> results) { - // If any operands are not yet ready, skip this operation for now. + // TODO: For sure not the right way to do this + // but why is scf.if not initialized otherwise? for (auto op : operands) if (op->getValue().getRank() == 0) - return success(); + setToEntryState((dataflow::Lattice *)op); AxisInfo curr = visitors.apply(op, operands); if (curr.getRank() == 0) { setAllToEntryStates(results); @@ -1111,11 +1112,9 @@ void AxisInfoAnalysis::visitForOpInductionVar( ProgramPoint *programPoint = getProgramPointAfter(op); auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound()); auto *stepLattice = getLatticeElementFor(programPoint, op.getStep()); - // If lb or step is not yet ready, skip this operation for now. - if (lbLattice->getValue().getRank() == 0 || - stepLattice->getValue().getRank() == 0) { - return; - } + for (auto op_iter : {lbLattice, stepLattice}) + if (op_iter->getValue().getRank() == 0) + setToEntryState((dataflow::Lattice *)op_iter); AxisInfo::DimVectorT knownContiguity(1, 1); AxisInfo::DimVectorT knownDivisibility(1, 1); @@ -1189,15 +1188,24 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, &knownContiguity, &knownDivisibility, &knownConstancy); - } else if (isa(op)) { - // Initialize the arguments to gpu::WarpSpecializePartitionsOp with - // "unknown" state: the maximum possible divisibility, contiguity, and - // constancy. + } else if (isa( + op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. knownDivisibility = DimVectorT(rank, kMaxDivisor); knownConstancy = DimVectorT(rank, kMaxDivisor); knownContiguity = DimVectorT(rank, kMaxDivisor); } } else if (Operation *op = value.getDefiningOp()) { + if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. + knownDivisibility = DimVectorT(rank, kMaxDivisor); + knownConstancy = DimVectorT(rank, kMaxDivisor); + knownContiguity = DimVectorT(rank, kMaxDivisor); + } // Other operations are conservatively initialized with the lowest possible // divisibility, contiguity, and constancy unless they have specified. AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"), @@ -1350,10 +1358,6 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp, auto *axisInfoMap = getFuncData(funcOp); auto updateAxisInfoMap = [&](Value value) { auto axisInfo = analysis->getLatticeElement(value)->getValue(); - // If we could not determine the AxisInfo for this value, assume the - // pessimistic state. - if (axisInfo.getRank() == 0) - axisInfo = AxisInfo::getPessimisticValueState(value); auto &valInfo = (*axisInfoMap)[value]; valInfo = AxisInfo::join(axisInfo, valInfo); }; diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 7f17dca4f9..49821e969f 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -1089,18 +1089,3 @@ tt.func public @test_inductor_for() { } tt.return } - -// ----- - -// Verify that if an operation is statically determined to be dead, we fall back -// to assigning it a pessimistic value, rather than skipping it entirely. -tt.func @dead_op_pessimistic() { - %c5 = arith.constant dense<5> : tensor<4xi32> - %c7 = arith.constant dense<7> : tensor<4xi32> - %false = arith.constant false - scf.if %false { - // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} - %add = arith.addi %c5, %c7 : tensor<4xi32> - } - tt.return -}