
Commit 0a93c96

plognjen and oplavsic authored
[AMD] Add bypassLDS feature to StreamPipeline (triton-lang#7968)
Determine if it is safe to bypass LDS for dot operands. Normally, dot operation operands are consumed in the dot MFMA layout, which is not coalesced. To better utilize global memory bandwidth, operands are usually loaded in a coalesced "blocked" layout and then rearranged through LDS. However, certain optimizations allow dot operands to be preshuffled in global memory. In that case, the operands can be loaded efficiently (in a coalesced way) and consumed directly by the dot operation.

When preshuffling is used, a sequence of transpose and reshape ops must be applied to the operand. To verify that preshuffling was done correctly and the final layout remains coalesced, we start from the dot MFMA layout and apply the inverse of each transpose/reshape op (while ignoring convert_layout ops) until we reach the load. We then inspect the resulting layout to decide whether it is coalesced enough to load directly, without any further rearrangement.

TODO: getContigPerThread does not work if elements are permuted within a thread. We need a utility similar to largestVectorisation() to detect this once load op vectorization supports in-thread permutations as well.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
1 parent e174882 commit 0a93c96
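
The walk described in the commit message can be pictured with a small, hypothetical C++ sketch (this is not the code added by the commit): starting from a dot operand, follow its def chain backwards through tt.trans and tt.reshape, skip ttg.convert_layout, and stop at the tt.load. The helper name traceDotOperandToLoad and the include paths are assumptions; the inverse-layout application and the final coalescing test performed by the actual StreamPipeline feature are only indicated in comments.

#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Walk a dot operand's def chain backwards to its load. tt.trans and
// tt.reshape are stepped through (the real pass applies their inverse to the
// dot MFMA layout at each step), ttg.convert_layout is ignored, and anything
// else ends the walk. The coalescing check on the resulting layout is elided.
static Operation *traceDotOperandToLoad(Value v) {
  while (Operation *def = v.getDefiningOp()) {
    if (isa<triton::LoadOp>(def))
      return def; // reached the global load; inspect its layout here
    if (isa<triton::TransOp, triton::ReshapeOp, triton::gpu::ConvertLayoutOp>(
            def)) {
      v = def->getOperand(0); // follow the source tensor
      continue;
    }
    return nullptr; // unexpected producer: do not bypass LDS
  }
  return nullptr;
}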

File tree

5 files changed (+336, −49 lines)


include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 3 additions & 2 deletions
@@ -213,8 +213,9 @@ std::optional<StringRef> getAMDArch(Operation *module);
 std::optional<mlir::triton::gpu::SwizzledSharedEncodingAttr>
 getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
 
-// Convert \param op operands and results to layout \param encoding.
-void convertOpEncoding(Attribute encoding, Operation *op);
+// Convert \param op to use \param encoding attribute.
+// Skips operands if they're in shared encoding.
+Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op);
 
 // Returns the original memory allocation for a memdesc value
 triton::gpu::LocalAllocOp findShmemAlloc(Value operand);
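
As a usage illustration of the renamed helper (a hypothetical sketch, not code from this commit), a pass that has already chosen a coalesced distributed encoding for a memory op could retag the op as below. The include path and the assumption that the declaration is reachable from the mlir namespace are taken on faith; error handling is omitted.

#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

using namespace mlir;

// Rebuild `op` so its distributed operands/results use `encoding`; operands
// already in shared encoding are skipped, convert_layout casts are inserted
// around the new op, and the original op is erased.
Operation *retagWithEncoding(Attribute encoding, Operation *op) {
  return convertDistributedOpEncoding(encoding, op);
}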

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

Lines changed: 1 addition & 44 deletions
@@ -109,49 +109,6 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
     return tensorType.cloneWithEncoding(encoding);
   }
 
-  void coalesceOp(Attribute encoding, Operation *op) {
-    OpBuilder builder(op);
-    // Convert operands
-    // For load/store with tensor pointers, we don't have to change the
-    // operands' type, we do this by changing the outputs' type of
-    // `make_tensor_ptr`
-    SmallVector<Value, 4> newArgs;
-    for (auto operand : op->getOperands()) {
-      auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
-      if (tensorType &&
-          !isa<triton::gpu::SharedEncodingTrait>(tensorType.getEncoding())) {
-        Type newType = getNewType(tensorType, encoding);
-        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
-            op->getLoc(), newType, operand));
-      } else {
-        newArgs.push_back(operand);
-      }
-    }
-
-    // Convert output types
-    SmallVector<Type, 4> newTypes;
-    for (auto t : op->getResultTypes()) {
-      bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
-      newTypes.push_back(isAsync ? t : getNewType(t, encoding));
-    }
-
-    // Construct new op with the new encoding
-    Operation *newOp =
-        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
-                       newTypes, op->getAttrs());
-
-    // Cast the results back to the original layout
-    for (size_t i = 0; i < op->getNumResults(); i++) {
-      Value newResult = newOp->getResult(i);
-      if (newTypes[i] != op->getResultTypes()[i]) {
-        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
-            op->getLoc(), op->getResult(i).getType(), newResult);
-      }
-      op->getResult(i).replaceAllUsesWith(newResult);
-    }
-    op->erase();
-  }
-
   void runOnOperation() override {
     // Run axis info analysis
     ModuleOp moduleOp = getOperation();
@@ -184,7 +141,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
     // 4. Convert the output of this new memory op back to L1
     // 5. Replace all the uses of the original memory op by the new one
     for (auto &kv : layoutMap) {
-      coalesceOp(kv.second, kv.first);
+      convertDistributedOpEncoding(kv.second, kv.first);
     }
   }
 };

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 49 additions & 0 deletions
@@ -1168,6 +1168,55 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
   return attr;
 }
 
+static Type getNewType(Type type, Attribute encoding) {
+  RankedTensorType tensorType = cast<RankedTensorType>(type);
+  return RankedTensorType::get(tensorType.getShape(),
+                               tensorType.getElementType(), encoding);
+}
+
+Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op) {
+  OpBuilder builder(op);
+  // Convert operands
+  // For load/store with tensor pointers, we don't have to change the
+  // operands' type, we do this by changing the outputs' type of
+  // `make_tensor_ptr`
+  SmallVector<Value, 4> newArgs;
+  for (auto operand : op->getOperands()) {
+    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+    if (tensorType &&
+        !isa<triton::gpu::SharedEncodingTrait>(tensorType.getEncoding())) {
+      Type newType = getNewType(tensorType, encoding);
+      newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
+          op->getLoc(), newType, operand));
+    } else {
+      newArgs.push_back(operand);
+    }
+  }
+
+  // Convert output types
+  SmallVector<Type, 4> newTypes;
+  for (auto t : op->getResultTypes()) {
+    bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
+    newTypes.push_back(isAsync ? t : getNewType(t, encoding));
+  }
+
+  // Construct new op with the new encoding
+  Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(),
+                                    newArgs, newTypes, op->getAttrs());
+
+  // Cast the results back to the original layout
+  for (size_t i = 0; i < op->getNumResults(); i++) {
+    Value newResult = newOp->getResult(i);
+    if (newTypes[i] != op->getResultTypes()[i]) {
+      newResult = builder.create<triton::gpu::ConvertLayoutOp>(
+          op->getLoc(), op->getResult(i).getType(), newResult);
+    }
+    op->getResult(i).replaceAllUsesWith(newResult);
+  }
+  op->erase();
+  return newOp;
+}
+
 namespace {
 
 /// Detect dead arguments in scf.for op by assuming all the values are dead and

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 105 additions & 0 deletions
@@ -917,3 +917,108 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+// COMMON-LABEL: bypass_lds_b_operand
+
+// SYNC: scf.for
+// SYNC: %[[load:.+]] = tt.load {{.*}} : tensor<8x2048x!tt.ptr<i8>, #linear>
+// SYNC: %[[reshape1:.+]] = tt.reshape %arg24
+// SYNC: %[[trans1:.+]] = tt.trans %[[reshape1]]
+// SYNC: %[[reshape2:.+]] = tt.reshape %[[trans1]]
+// SYNC: %[[trans2:.+]] = tt.trans %[[reshape2]] {{.*}} -> tensor<128x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+// SYNC: tt.dot_scaled {{.*}}, %[[trans2]]
+// SYNC: scf.yield {{.*}}, %[[load]]
+
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
+#linear = #ttg.linear<{register = [[0, 2], [0, 1]], lane = [[0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], warp = [[0, 0], [0, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0]], lane = [[0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 2, 0, 0, 0], [0, 0, 0, 4, 0, 0, 0], [0, 0, 0, 8, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0]], warp = [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]], block = []}>
+#linear2 = #ttg.linear<{register = [[0, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 0]], lane = [[0, 0, 1, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0], [0, 0, 4, 0, 0, 0, 0], [0, 0, 8, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 2, 0]], warp = [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]], block = []}>
+#linear3 = #ttg.linear<{register = [[0, 4], [16, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[0, 0], [0, 0]], block = []}>
+#linear4 = #ttg.linear<{register = [[0, 2], [0, 1]], lane = [[0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], warp = [[1, 0], [2, 0]], block = []}>
+#linear5 = #ttg.linear<{register = [[0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0]], lane = [[0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 2, 0, 0, 0], [0, 0, 0, 4, 0, 0, 0], [0, 0, 0, 8, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0]], warp = [[1, 0, 0, 0, 0, 0, 0], [2, 0, 0, 0, 0, 0, 0]], block = []}>
+#linear6 = #ttg.linear<{register = [[0, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 0]], lane = [[0, 0, 1, 0, 0, 0, 0], [0, 0, 2, 0, 0, 0, 0], [0, 0, 4, 0, 0, 0, 0], [0, 0, 8, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 2, 0]], warp = [[1, 0, 0, 0, 0, 0, 0], [2, 0, 0, 0, 0, 0, 0]], block = []}>
+#linear7 = #ttg.linear<{register = [[0, 4], [16, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[32, 0], [64, 0]], block = []}>
+#linear8 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 1024], [1, 0]], lane = [[0, 16], [0, 32], [0, 64], [0, 128], [0, 256], [0, 512]], warp = [[2, 0], [4, 0]], block = []}>
+#linear9 = #ttg.linear<{register = [[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 2], [0, 0, 0, 0, 0, 4], [0, 0, 0, 0, 0, 8], [0, 0, 4, 0, 0, 0], [0, 1, 0, 0, 0, 0]], lane = [[0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 2, 0], [0, 0, 0, 0, 4, 0], [0, 0, 0, 0, 8, 0], [0, 0, 1, 0, 0, 0], [0, 0, 2, 0, 0, 0]], warp = [[0, 2, 0, 0, 0, 0], [0, 4, 0, 0, 0, 0]], block = []}>
+#linear10 = #ttg.linear<{register = [[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 2], [0, 0, 0, 0, 0, 4], [0, 0, 0, 0, 0, 8], [0, 0, 0, 4, 0, 0], [0, 1, 0, 0, 0, 0]], lane = [[0, 0, 1, 0, 0, 0], [0, 0, 2, 0, 0, 0], [0, 0, 4, 0, 0, 0], [0, 0, 8, 0, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 2, 0, 0]], warp = [[0, 2, 0, 0, 0, 0], [0, 4, 0, 0, 0, 0]], block = []}>
+#linear11 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 64], [16, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 16], [0, 32]], warp = [[32, 0], [64, 0]], block = []}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], tilesPerWarp = [2, 2], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @bypass_lds_b_operand(%a_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %b_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %c_ptr: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %a_scales_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %b_scales_ptr: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bn: i32 {tt.divisibility = 16 : i32}, %stride_ck: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}, %stride_asm: i32 {tt.divisibility = 16 : i32}, %stride_bsn: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<128> : tensor<32x128xi32, #blocked>
+    %cst_0 = arith.constant dense<2048> : tensor<8x2048xi32, #blocked1>
+    %cst_1 = arith.constant dense<256> : tensor<4x256xi32, #blocked2>
+    %c1_i32 = arith.constant 1 : i32
+    %pid_unified = arith.constant 7 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %num_pid_n = arith.constant 127 : i32
+    %cst_2 = arith.constant dense<256> : tensor<1x256xi32, #blocked3>
+    %c128_i32 = arith.constant 128 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #mma>
+    %pid_unified_4 = tt.get_program_id x : i32
+    %xcd = arith.remsi %pid_unified_4, %c8_i32 : i32
+    %local_pid = arith.divsi %pid_unified_4, %c8_i32 : i32
+    %pid = arith.muli %xcd, %c8_i32 : i32
+    %pid_9 = arith.addi %pid, %local_pid : i32
+    %num_pid_n_7 = arith.addi %N, %num_pid_n : i32
+    %num_pid_n_8 = arith.divsi %num_pid_n_7, %c128_i32 : i32
+    %pid_n = arith.remsi %pid_9, %num_pid_n_8 : i32
+    %offs_bn = arith.muli %pid_n, %c8_i32 : i32
+    %offs_bn_15 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %offs_bn_16 = tt.splat %offs_bn : i32 -> tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %offs_bn_17 = arith.addi %offs_bn_16, %offs_bn_15 : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %offs_bn_18 = tt.splat %N : i32 -> tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %offs_bn_19 = arith.remsi %offs_bn_17, %offs_bn_18 : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %a_ptrs_28 = tt.splat %a_ptr : !tt.ptr<i8> -> tensor<32x128x!tt.ptr<i8>, #blocked>
+    %b_ptrs = tt.expand_dims %offs_bn_19 {axis = 1 : i32} : tensor<8xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<8x1xi32, #blocked1>
+    %b_ptrs_29 = tt.splat %stride_bn : i32 -> tensor<8x1xi32, #blocked1>
+    %b_ptrs_30 = arith.muli %b_ptrs, %b_ptrs_29 : tensor<8x1xi32, #blocked1>
+    %b_ptrs_31 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %b_ptrs_32 = tt.expand_dims %b_ptrs_31 {axis = 0 : i32} : tensor<2048xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x2048xi32, #blocked1>
+    %b_ptrs_33 = tt.broadcast %b_ptrs_30 : tensor<8x1xi32, #blocked1> -> tensor<8x2048xi32, #blocked1>
+    %b_ptrs_34 = tt.broadcast %b_ptrs_32 : tensor<1x2048xi32, #blocked1> -> tensor<8x2048xi32, #blocked1>
+    %b_ptrs_35 = arith.addi %b_ptrs_33, %b_ptrs_34 : tensor<8x2048xi32, #blocked1>
+    %b_ptrs_36 = tt.splat %b_ptr : !tt.ptr<i8> -> tensor<8x2048x!tt.ptr<i8>, #blocked1>
+    %b_ptrs_37 = tt.addptr %b_ptrs_36, %b_ptrs_35 : tensor<8x2048x!tt.ptr<i8>, #blocked1>, tensor<8x2048xi32, #blocked1>
+    %b_scale_ptrs_53 = tt.splat %b_scales_ptr : !tt.ptr<i8> -> tensor<4x256x!tt.ptr<i8>, #blocked2>
+    %a_scale_ptrs_56 = tt.splat %a_scales_ptr : !tt.ptr<i8> -> tensor<1x256x!tt.ptr<i8>, #blocked3>
+    %accumulator:5 = scf.for %accumulator_83 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%a_scale_ptrs_84 = %a_scale_ptrs_56, %arg16 = %cst_3, %b_scale_ptrs_85 = %b_scale_ptrs_53, %a_ptrs_86 = %a_ptrs_28, %b_ptrs_87 = %b_ptrs_37) -> (tensor<1x256x!tt.ptr<i8>, #blocked3>, tensor<32x128xf32, #mma>, tensor<4x256x!tt.ptr<i8>, #blocked2>, tensor<32x128x!tt.ptr<i8>, #blocked>, tensor<8x2048x!tt.ptr<i8>, #blocked1>) : i32 {
+      %a_scales = tt.load %a_scale_ptrs_84 : tensor<1x256x!tt.ptr<i8>, #blocked3>
+      %a_scales_88 = ttg.convert_layout %a_scales : tensor<1x256xi8, #blocked3> -> tensor<1x256xi8, #linear>
+      %a_scales_89 = tt.reshape %a_scales_88 : tensor<1x256xi8, #linear> -> tensor<1x1x4x16x2x2x1xi8, #linear1>
+      %a_scales_90 = tt.trans %a_scales_89 {order = array<i32: 0, 5, 3, 1, 4, 2, 6>} : tensor<1x1x4x16x2x2x1xi8, #linear1> -> tensor<1x2x16x1x2x4x1xi8, #linear2>
+      %a_scales_91 = tt.reshape %a_scales_90 : tensor<1x2x16x1x2x4x1xi8, #linear2> -> tensor<32x8xi8, #linear3>
+      %b_scales = tt.load %b_scale_ptrs_85 : tensor<4x256x!tt.ptr<i8>, #blocked2>
+      %b_scales_92 = ttg.convert_layout %b_scales : tensor<4x256xi8, #blocked2> -> tensor<4x256xi8, #linear4>
+      %b_scales_93 = tt.reshape %b_scales_92 : tensor<4x256xi8, #linear4> -> tensor<4x1x4x16x2x2x1xi8, #linear5>
+      %b_scales_94 = tt.trans %b_scales_93 {order = array<i32: 0, 5, 3, 1, 4, 2, 6>} : tensor<4x1x4x16x2x2x1xi8, #linear5> -> tensor<4x2x16x1x2x4x1xi8, #linear6>
+      %b_scales_95 = tt.reshape %b_scales_94 : tensor<4x2x16x1x2x4x1xi8, #linear6> -> tensor<128x8xi8, #linear7>
+      %a = tt.load %a_ptrs_86 : tensor<32x128x!tt.ptr<i8>, #blocked>
+      %b = tt.load %b_ptrs_87 : tensor<8x2048x!tt.ptr<i8>, #blocked1>
+      %accumulator_96 = ttg.convert_layout %b : tensor<8x2048xi8, #blocked1> -> tensor<8x2048xi8, #linear8>
+      %b_97 = tt.reshape %accumulator_96 : tensor<8x2048xi8, #linear8> -> tensor<1x8x8x1x16x16xi8, #linear9>
+      %b_98 = tt.trans %b_97 {order = array<i32: 0, 1, 4, 2, 3, 5>} : tensor<1x8x8x1x16x16xi8, #linear9> -> tensor<1x8x16x8x1x16xi8, #linear10>
+      %b_99 = tt.reshape %b_98 : tensor<1x8x16x8x1x16xi8, #linear10> -> tensor<128x128xi8, #linear11>
+      %b_100 = tt.trans %b_99 {order = array<i32: 1, 0>} : tensor<128x128xi8, #linear11> -> tensor<128x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+      %a_101 = ttg.convert_layout %a : tensor<32x128xi8, #blocked> -> tensor<32x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+      %accumulator_102 = tt.dot_scaled %a_101 scale %a_scales_91, %b_100 scale %b_scales_95, %cst_3 lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<32x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<32x8xi8, #linear3> * tensor<128x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<128x8xi8, #linear7> -> tensor<32x128xf32, #mma>
+      %accumulator_103 = arith.addf %arg16, %accumulator_102 : tensor<32x128xf32, #mma>
+      %a_ptrs_104 = tt.addptr %a_ptrs_86, %cst : tensor<32x128x!tt.ptr<i8>, #blocked>, tensor<32x128xi32, #blocked>
+      %b_ptrs_105 = tt.addptr %b_ptrs_87, %cst_0 : tensor<8x2048x!tt.ptr<i8>, #blocked1>, tensor<8x2048xi32, #blocked1>
+      %a_scale_ptrs_106 = tt.addptr %a_scale_ptrs_84, %cst_2 : tensor<1x256x!tt.ptr<i8>, #blocked3>, tensor<1x256xi32, #blocked3>
+      %b_scale_ptrs_107 = tt.addptr %b_scale_ptrs_85, %cst_1 : tensor<4x256x!tt.ptr<i8>, #blocked2>, tensor<4x256xi32, #blocked2>
+      scf.yield %a_scale_ptrs_106, %accumulator_103, %b_scale_ptrs_107, %a_ptrs_104, %b_ptrs_105 : tensor<1x256x!tt.ptr<i8>, #blocked3>, tensor<32x128xf32, #mma>, tensor<4x256x!tt.ptr<i8>, #blocked2>, tensor<32x128x!tt.ptr<i8>, #blocked>, tensor<8x2048x!tt.ptr<i8>, #blocked1>
+    }
+    tt.return
+  }
+}
