Commit 44abc7a

[AMD][LLVM] Scalarize packed fops in the same mfma/wmma block (#6656)
This PR adds an _LLVM pass_ that scalarizes vector `fmul`s and `fadd`s in basic blocks that contain MFMAs/WMMAs. The motivation is that these vector ops get codegened to "packed" instructions (`v_pk_mul_f32`/`v_pk_add_f32`), which cannot be co-issued with MFMAs, so they carry a performance cost. Concretely, the pass eliminates `v_pk_mul_f32`/`v_pk_add_f32` operations from the final asm in basic blocks that contain MFMAs. Note that the resulting "scalar" floating-point ops still get lowered to vector instructions such as `v_mul_f32_e32` and `v_add_u32_e32`, just not the "packed" variants. Also note that these packed fops aren't emitted by Triton per se: they are introduced by the `VectorCombine::foldPermuteOfBinops` pattern during the `optimize_module` pipeline, which is why this LLVM pass needs to run after that pipeline.
1 parent 1d6b7dd commit 44abc7a
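To illustrate the rewrite, here is a minimal hand-written LLVM IR sketch (not taken from the commit; the value names are made up): a two-lane `fmul` that would otherwise become `v_pk_mul_f32` is split into per-lane `extractelement`/`fmul`/`insertelement` sequences, which the backend then lowers to plain `v_mul_f32_e32` instructions.

; before scalarization: lowers to v_pk_mul_f32
%prod = fmul <2 x float> %a, %b

; after scalarization: lowers to two v_mul_f32_e32
%a0 = extractelement <2 x float> %a, i64 0
%b0 = extractelement <2 x float> %b, i64 0
%p0 = fmul float %a0, %b0
%tmp = insertelement <2 x float> undef, float %p0, i64 0
%a1 = extractelement <2 x float> %a, i64 1
%b1 = extractelement <2 x float> %b, i64 1
%p1 = fmul float %a1, %b1
%prod.scalarized = insertelement <2 x float> %tmp, float %p1, i64 1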

File tree

10 files changed: +453 -1 lines changed

.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ jobs:
   fi
   pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
   pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
+  TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py
   cd python/test/unit
   pytest --capture=tee-sys -rfs -n 12 language runtime \
     --ignore=language/test_line_info.py \

python/triton/knobs.py

Lines changed: 1 addition & 0 deletions
@@ -446,6 +446,7 @@ class amd_knobs(base_knobs):
     global_prefetch: env_int = env_int("TRITON_HIP_GLOBAL_PREFETCH")
     local_prefetch: env_int = env_int("TRITON_HIP_LOCAL_PREFETCH")
     use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY")
+    scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS")


 class proton_knobs(base_knobs):

third_party/amd/backend/compiler.py

Lines changed: 3 additions & 0 deletions
@@ -365,6 +365,9 @@ def make_llir(src, metadata, options):

     llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)

+    if knobs.amd.scalarize_packed_fops:
+        amd.add_scalarize_packed_fops_llvm_pass(fns[0])
+
     # Get some metadata
     metadata["shared"] = src.get_int_attr("ttg.shared")

third_party/amd/backend/driver.py

Lines changed: 1 addition & 1 deletion
@@ -530,7 +530,7 @@ def is_active():
     def get_current_target(self):
         device = self.get_current_device()
         device_properties = self.utils.get_device_properties(device)
-        arch = device_properties['arch']
+        arch = knobs.runtime.override_arch or device_properties['arch']
         warp_size = device_properties['warpSize']
         return GPUTarget("hip", arch.split(':')[0], warp_size)

third_party/amd/include/TritonAMDGPUToLLVM/Passes.h

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ namespace mlir::triton::AMD {
 /// @return created pass
 std::unique_ptr<OperationPass<ModuleOp>>
 createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
+
+void runScalarizePackedFOpsPass(llvm::Function &F);
+
 } // namespace mlir::triton::AMD

 namespace mlir::triton {

third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -24,12 +24,17 @@ add_triton_library(TritonAMDGPUToLLVM
   SchedInstructions.cpp
   UpcastMXFPToLLVM.cpp
   MembarUtility.cpp
+  ScalarizePackedFOps.cpp

   DEPENDS
   TritonAMDGPUConversionPassIncGen
+  LLVMIRIncGen

   LINK_LIBS PUBLIC
   TritonGPUToLLVM
   TritonAMDGPUIR
   TritonProtonToLLVM
+  LLVMCore
+  LLVMPasses
+  LLVMSupport
 )

third_party/amd/lib/TritonAMDGPUToLLVM/ScalarizePackedFOps.cpp

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
#include "TritonAMDGPUToLLVM/Passes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"

#define DEBUG_TYPE "tritonamdgpu-scalarize-packed-fops"

using namespace llvm;
using namespace llvm::PatternMatch;

namespace {

bool isMFMAorWMMA(Instruction &inst) {
  auto *callInst = llvm::dyn_cast<CallInst>(&inst);
  if (!callInst)
    return false;
  // E.g., tail call void asm sideeffect "s_waitcnt lgkmcnt(0) ", ""()
  if (callInst->isInlineAsm())
    return false;
  Function *calledFunc = callInst->getCalledFunction();
  if (!calledFunc->isIntrinsic())
    return false;
  StringRef intrinName = calledFunc->getName();
  if (intrinName.contains("mfma") || intrinName.contains("wmma"))
    return true;
  return false;
}

bool maybeReplaceVectorFOpWithScalarFOps(Instruction *inst,
                                         IRBuilder<> &builder) {
  Value *lhs, *rhs;
  if (!match(inst, m_BinOp(m_Value(lhs), m_Value(rhs))))
    return false;
  auto *VecLhs = dyn_cast<VectorType>(lhs->getType());
  if (!VecLhs)
    return false;
  assert(!VecLhs->isScalableTy() && "expected fixed-len vector");
  builder.SetInsertPoint(inst);
  Value *newVec = llvm::UndefValue::get(VecLhs);
  for (int i = 0; i < VecLhs->getElementCount().getFixedValue(); ++i) {
    Value *newLhs = builder.CreateExtractElement(lhs, i);
    Value *newRhs = builder.CreateExtractElement(rhs, i);
    Value *res;
    if (inst->getOpcode() == Instruction::FMul)
      res = builder.CreateFMul(newLhs, newRhs);
    else if (inst->getOpcode() == Instruction::FAdd)
      res = builder.CreateFAdd(newLhs, newRhs);
    else if (inst->getOpcode() == Instruction::FSub)
      res = builder.CreateFSub(newLhs, newRhs);
    else
      llvm::report_fatal_error("only fadd, fmul, fsub supported");
    newVec = builder.CreateInsertElement(newVec, res, i);
  }
  LLVM_DEBUG(dbgs() << "ScalarizePackedFOps: Replacing: " << inst << '\n');
  LLVM_DEBUG(dbgs() << " With: " << newVec << '\n');
  inst->replaceAllUsesWith(newVec);
  return true;
}

// This pass scalarizes vector `fmul`s and `fadd`s in basic blocks that contain
// MFMAs. The point of doing this is that these get codegened to "packed" ops
// (`v_pk_mul_f32`/`v_pk_add_f32`) and, while packed ops use separate VALUs
// from MFMA tensor cores (no problem there), the instructions themselves
// cannot be *issued* in parallel, thus there is a performance cost to having
// such packed ops "near" MFMAs. Concretely, this eliminates
// `v_pk_mul_f32`/`v_pk_add_f32` operations in the final asm in bbs with MFMAs.
//
// Note, these "scalar" floating point ops will still get lowered to vector
// instructions like `v_mul_f32_e32 v1, v163, v114` and
// `v_add_u32_e32 v1, s16, v12`, just not the "packed" variants.
//
// Note, these vectorized `fmul`s aren't actually emitted by triton per se -
// they are introduced/inserted by the VectorCombine::foldPermuteOfBinops
// pattern during the `optimize_module` pipeline (hence why this LLVM pass
// needs to follow that pipeline).
struct ScalarizePackedFOps : FunctionPass {
  ScalarizePackedFOps() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override {
    IRBuilder builder(F.getContext());
    bool changed = false;
    SmallVector<Instruction *> instsToErase;
    for (BasicBlock &BB : F) {
      if (!llvm::any_of(BB, isMFMAorWMMA))
        continue;
      for (Instruction &inst : BB) {
        if (inst.getOpcode() != Instruction::FMul &&
            inst.getOpcode() != Instruction::FAdd &&
            inst.getOpcode() != Instruction::FSub)
          continue;
        if (maybeReplaceVectorFOpWithScalarFOps(&inst, builder)) {
          instsToErase.push_back(&inst);
          changed = true;
        }
      }
    }

    if (changed) {
      for (Instruction *inst : instsToErase) {
        if (inst)
          inst->eraseFromParent();
      }
    }

    // We don't do anything with this, but this is a virtual function override
    // and the signature requires it.
    return changed;
  }

  static char ID;
};

} // end anonymous namespace

char ScalarizePackedFOps::ID = 0;

namespace mlir::triton::AMD {
void runScalarizePackedFOpsPass(Function &F) {
  ScalarizePackedFOps pass;
  pass.runOnFunction(F);
  // If there are no errors, verifyFunction returns false.
  assert(!llvm::verifyFunction(F) &&
         "expected function to verify successfully");
}
} // namespace mlir::triton::AMD

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
module {
  tt.func public @attn_fwd(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32 {tt.divisibility = 16 : i32}, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: f32, %arg24: i32, %arg25: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg26: i32) attributes {noinline = false} {
    %c8192_i32 = arith.constant 8192 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32>
    %cst_0 = arith.constant dense<0.127517432> : tensor<256xf32>
    %cst_1 = arith.constant dense<0.127517432> : tensor<256x64xf32>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32>
    %c16640_i32 = arith.constant 16640 : i32
    %c786432_i32 = arith.constant 786432 : i32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<256x128xf16>
    %cst_4 = arith.constant dense<true> : tensor<256x128xi1>
    %cst_5 = arith.constant dense<1.000000e+00> : tensor<256x1xf32>
    %cst_6 = arith.constant dense<16384> : tensor<256x1xi32>
    %cst_7 = arith.constant dense<1.000000e+00> : tensor<256xf32>
    %cst_8 = arith.constant dense<0xFF800000> : tensor<256xf32>
    %c64_i32 = arith.constant 64 : i32
    %c16384_i32 = arith.constant 16384 : i32
    %c256_i32 = arith.constant 256 : i32
    %c1_i32 = arith.constant 1 : i32
    %true = arith.constant true
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.cmpi sge, %arg5, %c0_i32 : i32
    llvm.intr.assume %0 : i1
    %1 = arith.cmpi sge, %arg6, %c0_i32 : i32
    llvm.intr.assume %1 : i1
    %2 = arith.cmpi sge, %arg7, %c0_i32 : i32
    llvm.intr.assume %2 : i1
    llvm.intr.assume %true : i1
    %3 = arith.cmpi sge, %arg8, %c0_i32 : i32
    llvm.intr.assume %3 : i1
    %4 = arith.cmpi sge, %arg9, %c0_i32 : i32
    llvm.intr.assume %4 : i1
    %5 = arith.cmpi sge, %arg10, %c0_i32 : i32
    llvm.intr.assume %5 : i1
    llvm.intr.assume %true : i1
    %6 = arith.cmpi sge, %arg17, %c0_i32 : i32
    llvm.intr.assume %6 : i1
    %7 = arith.cmpi sge, %arg18, %c0_i32 : i32
    llvm.intr.assume %7 : i1
    %8 = arith.cmpi sge, %arg19, %c0_i32 : i32
    llvm.intr.assume %8 : i1
    %9 = arith.cmpi sge, %arg20, %c0_i32 : i32
    llvm.intr.assume %9 : i1
    %10 = arith.cmpi sge, %arg11, %c0_i32 : i32
    llvm.intr.assume %10 : i1
    %11 = arith.cmpi sge, %arg12, %c0_i32 : i32
    llvm.intr.assume %11 : i1
    %12 = arith.cmpi sge, %arg13, %c0_i32 : i32
    llvm.intr.assume %12 : i1
    llvm.intr.assume %true : i1
    %13 = arith.cmpi sge, %arg14, %c0_i32 : i32
    llvm.intr.assume %13 : i1
    %14 = arith.cmpi sge, %arg15, %c0_i32 : i32
    llvm.intr.assume %14 : i1
    %15 = arith.cmpi sge, %arg16, %c0_i32 : i32
    llvm.intr.assume %15 : i1
    llvm.intr.assume %true : i1
    %16 = tt.get_program_id x : i32
    %17 = tt.get_program_id y : i32
    %18 = tt.get_program_id z : i32
    %19 = arith.muli %16, %c256_i32 : i32
    %20 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
    %21 = tt.splat %19 : i32 -> tensor<256xi32>
    %22 = arith.addi %21, %20 : tensor<256xi32>
    %23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %25 = arith.muli %18, %arg5 : i32
    %26 = tt.addptr %arg0, %25 : !tt.ptr<f16>, i32
    %27 = arith.muli %17, %arg6 : i32
    %28 = tt.addptr %26, %27 : !tt.ptr<f16>, i32
    %29 = tt.expand_dims %22 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32>
    %30 = tt.splat %arg7 : i32 -> tensor<256x1xi32>
    %31 = arith.muli %29, %30 : tensor<256x1xi32>
    %32 = tt.splat %28 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>>
    %33 = tt.addptr %32, %31 : tensor<256x1x!tt.ptr<f16>>, tensor<256x1xi32>
    %34 = tt.expand_dims %24 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32>
    %35 = tt.broadcast %33 : tensor<256x1x!tt.ptr<f16>> -> tensor<256x128x!tt.ptr<f16>>
    %36 = tt.broadcast %34 : tensor<1x128xi32> -> tensor<256x128xi32>
    %37 = tt.addptr %35, %36 : tensor<256x128x!tt.ptr<f16>>, tensor<256x128xi32>
    %38 = arith.muli %18, %arg8 : i32
    %39 = tt.addptr %arg1, %38 : !tt.ptr<f16>, i32
    %40 = arith.muli %17, %arg9 : i32
    %41 = tt.addptr %39, %40 : !tt.ptr<f16>, i32
    %42 = tt.expand_dims %24 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32>
    %43 = tt.splat %41 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>>
    %44 = tt.addptr %43, %42 : tensor<128x1x!tt.ptr<f16>>, tensor<128x1xi32>
    %45 = tt.expand_dims %23 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
    %46 = tt.splat %arg10 : i32 -> tensor<1x64xi32>
    %47 = arith.muli %45, %46 : tensor<1x64xi32>
    %48 = tt.broadcast %44 : tensor<128x1x!tt.ptr<f16>> -> tensor<128x64x!tt.ptr<f16>>
    %49 = tt.broadcast %47 : tensor<1x64xi32> -> tensor<128x64xi32>
    %50 = tt.addptr %48, %49 : tensor<128x64x!tt.ptr<f16>>, tensor<128x64xi32>
    %51 = arith.muli %18, %arg11 : i32
    %52 = tt.addptr %arg2, %51 : !tt.ptr<f16>, i32
    %53 = arith.muli %17, %arg12 : i32
    %54 = tt.addptr %52, %53 : !tt.ptr<f16>, i32
    %55 = tt.expand_dims %23 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
    %56 = tt.splat %arg13 : i32 -> tensor<64x1xi32>
    %57 = arith.muli %55, %56 : tensor<64x1xi32>
    %58 = tt.splat %54 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>>
    %59 = tt.addptr %58, %57 : tensor<64x1x!tt.ptr<f16>>, tensor<64x1xi32>
    %60 = tt.broadcast %59 : tensor<64x1x!tt.ptr<f16>> -> tensor<64x128x!tt.ptr<f16>>
    %61 = tt.broadcast %34 : tensor<1x128xi32> -> tensor<64x128xi32>
    %62 = tt.addptr %60, %61 : tensor<64x128x!tt.ptr<f16>>, tensor<64x128xi32>
    %63 = arith.cmpi slt, %29, %cst_6 : tensor<256x1xi32>
    %64 = tt.broadcast %63 : tensor<256x1xi1> -> tensor<256x128xi1>
    %65 = arith.muli %arg10, %c64_i32 : i32
    %66 = tt.splat %65 : i32 -> tensor<128x64xi32>
    %67 = arith.muli %arg13, %c64_i32 : i32
    %68 = tt.splat %67 : i32 -> tensor<64x128xi32>
    %69 = arith.addi %16, %c1_i32 : i32
    %70 = arith.muli %69, %c256_i32 : i32
    %71 = arith.muli %18, %c786432_i32 : i32
    %72 = tt.addptr %arg3, %71 : !tt.ptr<f32>, i32
    %73 = arith.muli %17, %c16384_i32 : i32
    %74 = tt.addptr %72, %73 : !tt.ptr<f32>, i32
    %75 = tt.splat %74 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>>
    %76 = tt.addptr %75, %22 : tensor<256x!tt.ptr<f32>>, tensor<256xi32>
    %77 = arith.subi %70, %c16384_i32 : i32
    %78 = arith.cmpi sgt, %77, %c0_i32 : i32
    %79 = arith.muli %18, %arg14 : i32
    %80 = tt.addptr %arg4, %79 : !tt.ptr<f16>, i32
    %81 = arith.muli %17, %arg15 : i32
    %82 = tt.addptr %80, %81 : !tt.ptr<f16>, i32
    %83 = tt.splat %arg16 : i32 -> tensor<256x1xi32>
    %84 = arith.muli %29, %83 : tensor<256x1xi32>
    %85 = tt.splat %82 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>>
    %86 = tt.addptr %85, %84 : tensor<256x1x!tt.ptr<f16>>, tensor<256x1xi32>
    %87 = tt.broadcast %86 : tensor<256x1x!tt.ptr<f16>> -> tensor<256x128x!tt.ptr<f16>>
    %88 = tt.addptr %87, %36 : tensor<256x128x!tt.ptr<f16>>, tensor<256x128xi32>
    %89 = scf.if %78 -> (tensor<256x128xi1>) {
      scf.yield %64 : tensor<256x128xi1>
    } else {
      scf.yield %cst_4 : tensor<256x128xi1>
    }
    scf.while (%arg27 = %c0_i32) : (i32) -> () {
      %90 = arith.cmpi slt, %arg27, %c1_i32 : i32
      scf.condition(%90)
    } do {
      %90 = tt.load %37, %64, %cst_3 : tensor<256x128x!tt.ptr<f16>>
      %91:5 = scf.for %arg27 = %c0_i32 to %c8192_i32 step %c64_i32 iter_args(%arg28 = %cst_2, %arg29 = %cst_7, %arg30 = %cst_8, %arg31 = %50, %arg32 = %62) -> (tensor<256x128xf32>, tensor<256xf32>, tensor<256xf32>, tensor<128x64x!tt.ptr<f16>>, tensor<64x128x!tt.ptr<f16>>) : i32 {
        %97 = tt.load %arg31 : tensor<128x64x!tt.ptr<f16>>
        %98 = tt.dot %90, %97, %cst : tensor<256x128xf16> * tensor<128x64xf16> -> tensor<256x64xf32>
        %99 = "tt.reduce"(%98) <{axis = 1 : i32}> ({
        ^bb0(%arg33: f32, %arg34: f32):
          %121 = arith.maxnumf %arg33, %arg34 : f32
          tt.reduce.return %121 : f32
        }) : (tensor<256x64xf32>) -> tensor<256xf32>
        %100 = arith.maxnumf %arg30, %99 : tensor<256xf32>
        %101 = arith.mulf %100, %cst_0 : tensor<256xf32>
        %102 = arith.mulf %98, %cst_1 : tensor<256x64xf32>
        %103 = tt.expand_dims %101 {axis = 1 : i32} : tensor<256xf32> -> tensor<256x1xf32>
        %104 = tt.broadcast %103 : tensor<256x1xf32> -> tensor<256x64xf32>
        %105 = arith.subf %102, %104 : tensor<256x64xf32>
        %106 = math.exp2 %105 : tensor<256x64xf32>
        %107 = "tt.reduce"(%106) <{axis = 1 : i32}> ({
        ^bb0(%arg33: f32, %arg34: f32):
          %121 = arith.addf %arg33, %arg34 : f32
          tt.reduce.return %121 : f32
        }) : (tensor<256x64xf32>) -> tensor<256xf32>
        %108 = arith.mulf %arg30, %cst_0 : tensor<256xf32>
        %109 = arith.subf %108, %101 : tensor<256xf32>
        %110 = math.exp2 %109 : tensor<256xf32>
        %111 = tt.expand_dims %110 {axis = 1 : i32} : tensor<256xf32> -> tensor<256x1xf32>
        %112 = tt.broadcast %111 : tensor<256x1xf32> -> tensor<256x128xf32>
        %113 = arith.mulf %arg28, %112 : tensor<256x128xf32>
        %114 = tt.load %arg32 : tensor<64x128x!tt.ptr<f16>>
        %115 = arith.mulf %arg29, %110 : tensor<256xf32>
        %116 = arith.addf %115, %107 : tensor<256xf32>
        %117 = arith.truncf %106 : tensor<256x64xf32> to tensor<256x64xf16>
        %118 = tt.dot %117, %114, %113 : tensor<256x64xf16> * tensor<64x128xf16> -> tensor<256x128xf32>
        %119 = tt.addptr %arg31, %66 : tensor<128x64x!tt.ptr<f16>>, tensor<128x64xi32>
        %120 = tt.addptr %arg32, %68 : tensor<64x128x!tt.ptr<f16>>, tensor<64x128xi32>
        scf.yield %118, %116, %100, %119, %120 : tensor<256x128xf32>, tensor<256xf32>, tensor<256xf32>, tensor<128x64x!tt.ptr<f16>>, tensor<64x128x!tt.ptr<f16>>
      }
      gpu.barrier
      %92 = tt.expand_dims %91#1 {axis = 1 : i32} : tensor<256xf32> -> tensor<256x1xf32>
      %93 = arith.divf %cst_5, %92 : tensor<256x1xf32>
      %94 = tt.broadcast %93 : tensor<256x1xf32> -> tensor<256x128xf32>
      %95 = arith.mulf %91#0, %94 : tensor<256x128xf32>
      %96 = arith.truncf %95 : tensor<256x128xf32> to tensor<256x128xf16>
      scf.if %78 {
        %97 = arith.subi %c16640_i32, %70 : i32
        %98 = tt.splat %97 : i32 -> tensor<256xi32>
        %99 = arith.cmpi slt, %20, %98 : tensor<256xi32>
        %100 = math.log2 %91#1 : tensor<256xf32>
        %101 = arith.addf %91#2, %100 : tensor<256xf32>
        tt.store %76, %101, %99 : tensor<256x!tt.ptr<f32>>
      } else {
        %97 = math.log2 %91#1 : tensor<256xf32>
        %98 = arith.addf %91#2, %97 : tensor<256xf32>
        tt.store %76, %98 : tensor<256x!tt.ptr<f32>>
      }
      tt.store %88, %96, %89 : tensor<256x128x!tt.ptr<f16>>
      scf.yield %c1_i32 : i32
    }
    tt.return
  }
}
