
Commit 5e59bdf

Implement dot_scaled(mmav3) (#5269)
As per title
1 parent dbebe10 commit 5e59bdf

8 files changed, +158 -103 lines changed


include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 4 additions & 0 deletions
@@ -214,6 +214,10 @@ LinearLayout ensureLayoutNotSmallerThan(
     const LinearLayout &layout,
     const llvm::SmallDenseMap<StringAttr, int64_t> &shape);
 
+SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
+LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
+                                ArrayRef<unsigned> order);
+
 // Dump information about which threads/registers contain each of the tensor
 // elements.
 void dumpLayout(RankedTensorType tensorType);

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
       auto dotParent = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent());
-      if (dotParent && dotParent.isAmpere()) {
+      if (dotParent) {
        return;
      }
      Attribute sharedMemorySpace =

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 30 additions & 0 deletions
@@ -642,6 +642,36 @@ LinearLayout ensureLayoutNotSmallerThan(
   return ret;
 }
 
+// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
+SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
+  SmallVector<StringAttr> ret;
+  for (int i = 0; i < rank; i++) {
+    ret.push_back(StringAttr::get(ctx, "dim" + llvm::Twine(i)));
+  }
+  return ret;
+}
+
+// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
+// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
+// permute(shape, order).
+LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
+                                ArrayRef<unsigned> order) {
+  assert(shape.size() == order.size());
+  MLIRContext *ctx = inDimName.getContext();
+  auto rank = shape.size();
+
+  // The order in triton is written wrt. [dim0, dim1, ...].
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
+  LinearLayout ret = LinearLayout::empty();
+  for (int i = 0; i < shape.size(); i++) {
+    // Start with the most-minor dimension, which is order[0].
+    int dim = order[i];
+    ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]);
+  }
+  return ret;
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
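For intuition, here is a minimal standalone sketch (plain C++, not the MLIR LinearLayout API) of what the moved identityStandardND helper computes, per its own comment: an input index is decomposed most-minor-first along order[0], order[1], ..., which is equivalent to reshaping a flat range of size product(shape) into permute(shape, order). The function and example values below are illustrative assumptions, not code from the commit.

// Hedged model of identityStandardND: decompose a flat index into ND
// coordinates, walking dimensions from most-minor (order[0]) outward.
#include <cstdio>
#include <vector>

std::vector<unsigned> coordsOf(unsigned idx, const std::vector<unsigned> &shape,
                               const std::vector<unsigned> &order) {
  std::vector<unsigned> coords(shape.size(), 0);
  for (unsigned dim : order) { // order[0] is the most-minor dimension
    coords[dim] = idx % shape[dim];
    idx /= shape[dim];
  }
  return coords;
}

int main() {
  // shape = [4, 8] with order = [1, 0]: dim1 varies fastest (a row-major tile).
  for (unsigned idx : {0u, 1u, 7u, 8u, 9u}) {
    auto c = coordsOf(idx, {4, 8}, {1, 0});
    std::printf("idx %2u -> (dim0=%u, dim1=%u)\n", idx, c[0], c[1]);
  }
  return 0;
}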

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 0 additions & 30 deletions
@@ -32,15 +32,6 @@ namespace {
 
 #define S(v) StringAttr::get(ctx, (v))
 
-// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
-SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
-  SmallVector<StringAttr> ret;
-  for (int i = 0; i < rank; i++) {
-    ret.push_back(S("dim" + llvm::Twine(i)));
-  }
-  return ret;
-}
-
 // TODO Have order be a mandatory argument of standardOutDimNames.
 SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
                                         const SmallVector<unsigned> &order) {
@@ -52,27 +43,6 @@ SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
   return ret;
 }
 
-// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
-// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
-// permute(shape, order).
-LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
-                                ArrayRef<unsigned> order) {
-  assert(shape.size() == order.size());
-  MLIRContext *ctx = inDimName.getContext();
-  auto rank = shape.size();
-
-  // The order in triton is written wrt. [dim0, dim1, ...].
-  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
-
-  LinearLayout ret = LinearLayout::empty();
-  for (int i = 0; i < shape.size(); i++) {
-    // Start with the most-minor dimension, which is order[0].
-    int dim = order[i];
-    ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]);
-  }
-  return ret;
-}
-
 // Make a LinearLayout that maps a block-id to an N-dimensional index.
 //
 // The tensor is split up into CTAsPerCGA pieces, which are distributed among

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 73 additions & 44 deletions
@@ -12,6 +12,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -394,6 +395,10 @@ class DecomposeScaledBlocked
     auto aType = scaledDotOp.getLhsType();
     auto bType = scaledDotOp.getRhsType();
 
+    auto rank = oldRetType.getShape().size();
+    if (rank != 2)
+      return rewriter.notifyMatchFailure(scaledDotOp, "NYI: rank==3");
+
     assert((aType == ScaleDotElemType::E4M3 ||
             aType == ScaleDotElemType::E5M2 ||
             aType == ScaleDotElemType::E2M1) &&
@@ -430,71 +435,95 @@
     // `bases[warps] = {(0, 0), (0, 0), ...}`
 
     auto newAEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaEnc, aKWidth);
-    auto rank = mmaEnc.getInstrShape().size();
+
     // MMAv3 uses the first dimension for the M dimension, while MMAv2 uses the
     // penultimate (ugh)
-    auto instrShapeM = mmaEnc.getInstrShape()[versionMajor == 3 ? 0 : rank - 2];
+    auto instrShapeM =
+        mmaEnc.getInstrShape()[versionMajor == 3
+                                   ? 0
+                                   : mmaEnc.getInstrShape().size() - 2];
     auto warpSize = getWarpSize(newAEncoding);
     assert(instrShapeM <= warpSize);
     // Necessary choice to leave all the scales of the tile in that given warp
     auto threadsPerWarp =
         SmallVector<unsigned>{instrShapeM, warpSize / instrShapeM};
 
-    assert(versionMajor == 2 &&
-           "NYI: MMAv3. Need to rethink the scale layout otherwise");
-
-    // Copy the bases
-
+    // This has to align with the order in UpcastMXFPOp
+    auto order = getMatrixOrder(rank, /*rowMajor=*/true);
     Attribute newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
-        ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(),
-        newAEncoding.getCTAOrder(), mmaEnc.getCTALayout());
+        ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), order,
+        mmaEnc.getCTALayout());
 
+    // Lezcano: In the future we could just use the LLs unconditionally
+    // Not doing it now as they are not as performant as Blocked encoding at
+    // times E.g., we bail on them in the backwardMaterialization pass
     auto dotBroadcastsWarpLevel = mmaEnc.getWarpsPerCTA()[1] != 1;
     if (dotBroadcastsWarpLevel) {
-      // If mma has warpsPerCTA == {2, 2}, then newAEncoding has
-      // warpsPerCTA == {2, 1}. In this case, we need to broadcast the warps
-      // on the second dimension as per
-      // A: 0 1 | 0 1
-      //    - - | - -
-      //    2 3 | 2 3
-      // This broadcasting is not representable by standard blocked encodings,
-      // so we need to use linear layouts.
-      // This broadcasting is implemented in ampereDotToLinearLayout
-      auto blocked = cast<BlockedEncodingAttr>(newScaleEncoding);
-      auto blockedLL = *blocked.toLinearLayout(a.getType().getShape());
-      LinearLayout::BasesT scaleBases = blockedLL.getBases();
-      auto nBases = llvm::Log2_32(mmaEnc.getWarpsPerCTA()[1]);
-      auto &warps = scaleBases[StringAttr::get(ctx, "warp")];
-      // Prepend the vector of zeros to the warpBases
-      warps.insert(warps.begin(), nBases, std::vector<int32_t>(rank, 0));
-      auto outDims = llvm::to_vector(blockedLL.getOutDimNames());
-      auto newLL = LinearLayout(scaleBases, outDims);
-      auto llEncoding = LinearEncodingAttr::get(ctx, std::move(newLL));
-      // Adjust the shape of the layout to match the scale operand
-      auto scaleShape = scale.getType().getShape();
-      newScaleEncoding =
-          LinearEncodingAttr::get(ctx, *llEncoding.toLinearLayout(scaleShape));
+      auto kRegister = StringAttr::get(ctx, "register");
+      auto regs = identityStandardND(kRegister, {1, 1}, order);
+      auto lanes =
+          identityStandardND(StringAttr::get(ctx, "lane"), {16, 2}, order);
+
+      // Extract warp layout from dotAEncoding
+      // In the future we'll have some nice division utils, but until then...
+      auto dotLL = *newAEncoding.toLinearLayout(a.getType().getShape());
+      LinearLayout::BasesT scaleBases = dotLL.getBases();
+      auto kWarp = StringAttr::get(ctx, "warp");
+      auto &warpBases = scaleBases[kWarp];
+      // The tile shape was [16, 2 * 4 * kWidth] with broadcasting in K
+      // We divide the M dimension by 16
+      auto div = 16;
+      for (auto &warpBase : warpBases) {
+        if (warpBase[rank - 2] != 0) {
+          assert(warpBase[rank - 2] % div == 0);
+          warpBase[rank - 2] /= div;
+        }
+      }
+
+      LinearLayout::BasesT warpBlockBases;
+      auto standardOutDims = llvm::to_vector(dotLL.getOutDimNames());
+      warpBlockBases[kWarp] = warpBases;
+      auto kBlock = StringAttr::get(ctx, "block");
+      assert(scaleBases[kBlock].empty() && "NYI: CGAs");
+      warpBlockBases[kBlock] = {};
+      auto warpBlock = LinearLayout(std::move(warpBlockBases), standardOutDims);
+
+      auto newLL =
+          (regs * lanes) *
+          warpBlock.transposeOuts(llvm::to_vector(lanes.getOutDimNames()));
+      auto shape = scale.getType().getShape();
+
+      // Broadcast to the correct shape Equivalent to
+      // newLL = ensureLayoutNotSmallerThan(newLL.transposeOuts(getRepOrder),
+      // shape);
+      for (auto d : newAEncoding.getRepOrder()) {
+        auto outDim = standardOutDims[d];
+        auto dimSize = newLL.getOutDimSize(outDim);
+        newLL *=
+            LinearLayout::identity1D(shape[d] / dimSize, kRegister, outDim);
+      }
+      newLL = newLL.transposeOuts(standardOutDims);
+      newScaleEncoding = LinearEncodingAttr::get(ctx, std::move(newLL));
     }
 
     a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding);
 
-    // Upcast B operand
-    assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4");
-    auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth);
-    b = createArg(rewriter, b, 1, bType, newBEncoding,
-                  /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt);
     Operation *newDot = nullptr;
     if (versionMajor == 2) {
+      // Upcast B operand
+      assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4");
+      auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth);
+      b = createArg(rewriter, b, 1, bType, newBEncoding,
+                    /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt);
       newDot = rewriter.create<DotOp>(scaledDotOp.getLoc(), newRetType, a, b,
                                       newAcc);
     } else {
      assert(versionMajor == 3);
      // At the time of this writing, this is always true
      auto allowTranspose = b.getType().getElementType().isBF16();
-      b = cast<TypedValue<RankedTensorType>>(
-          getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose));
+      auto bShmem = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
      newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
-          scaledDotOp.getLoc(), newRetType, a, b, newAcc, nullptr);
+          scaledDotOp.getLoc(), newRetType, a, bShmem, newAcc, nullptr);
     }
 
     // convert dot instruction
@@ -578,11 +607,11 @@ class DecomposeScaledBlocked
     auto dotOp = rewriter.create<DotOp>(
         scaledDotOp.getLoc(), scaledDotOp.getType(), a, b, scaledDotOp.getC());
 
-    // Waiting for https://github.com/triton-lang/triton/pull/5003 to land
-    // cf.
-    // https://github.com/triton-lang/triton/pull/5003#issuecomment-2445091746
-    // int versionMajor = getMMAVersionSafe(computeCapability, dotOp);
     int versionMajor = 2;
+    // We just support bf16 for MMAv3 on the rhs
+    if (bType == ScaleDotElemType::BF16) {
+      versionMajor = getMMAVersionSafe(computeCapability, dotOp);
+    }
     int versionMinor = computeCapability == 75 ? 1 : 0;
 
     RankedTensorType oldRetType = dotOp.getType();
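A note on the scale-layout construction in the hunk above: the warp bases for the scale operand are taken from the A-operand dot layout, but an MMA tile covers 16 rows of A while contributing only one row of scales, so the M component of every nonzero warp basis is divided by 16. Below is a minimal standalone sketch of just that rescaling step; the basis values are hypothetical and only illustrate the arithmetic, this is not the MLIR code.

// Hedged sketch of the warp-basis rescaling: a 16-row tile of A maps to a
// single row of scales, so M strides in the warp bases shrink by 16.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical rank-2 warp bases (coordinates are [M, K]) for a
  // warpsPerCTA = {2, 2} layout; values chosen for illustration only.
  std::vector<std::vector<int>> warpBases = {{0, 0}, {16, 0}};
  const int rank = 2;
  const int div = 16; // rows of A covered per row of scales
  for (auto &base : warpBases) {
    if (base[rank - 2] != 0) {
      assert(base[rank - 2] % div == 0);
      base[rank - 2] /= div;
    }
  }
  for (const auto &base : warpBases)
    std::printf("(%d, %d)\n", base[0], base[1]);
  return 0;
}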

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 20 additions & 5 deletions
@@ -164,7 +164,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
-// Verify that dot_scaled (mxfp4 x bf16) decomposes as expected
+// Verify that dot_scaled (mxfp4 x {bf16,fp8}) decomposes to mmav3 if it's bf16, otherwise it fallsback to mmav2
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
@@ -174,13 +174,28 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
   tt.func @dot_scaled(
     %a: tensor<128x32xi8, #blocked2>,
     %scale: tensor<128x2xi8, #blocked1>,
-    %b: tensor<64x128xbf16, #blocked>)
-    -> tensor<128x128xf32, #blocked> {
+    %b_bf16: tensor<64x128xbf16, #blocked>
+  ) -> tensor<128x128xf32, #blocked> {
+    // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, {{.*}}>
+    // CHECK: triton_gpu.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{{.*}}>>, tensor<128x2xi8, {{.*}}> -> tensor<128x64xbf16, #triton_gpu.dot_op<{{.*}}>>
+    // CHECK: triton_nvidia_gpu.warp_group_dot
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %result = tt.dot_scaled %a scale %scale, %b_bf16, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.return %result : tensor<128x128xf32, #blocked>
+  }
+
+  // Verify that dot_scaled (mxfp4 x fp8) decomposes into mmav2
+  // CHECK: dot_scaled_fp8
+  tt.func @dot_scaled_fp8(
+    %a: tensor<128x32xi8, #blocked2>,
+    %scale: tensor<128x2xi8, #blocked1>,
+    %b_fp8: tensor<64x128xf8E4M3FN, #blocked>
+  ) -> tensor<128x128xf32, #blocked> {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
     // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]>
     // CHECK: triton_gpu.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{{.*}}>>
     // CHECK: tt.dot
-    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
-    %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked>
+    %result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked>
     tt.return %result : tensor<128x128xf32, #blocked>
   }
 }
third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 1 addition & 9 deletions
@@ -74,15 +74,7 @@ struct DecomposeUnsupportedConversions
     // Remove the decomposeTensorCoreToDotLayoutConversion class entirely after
     // we have enabled the new layout conversion for all the cases.
     auto nvidiaShortCutFn = [&](RankedTensorType srcTy,
-                                RankedTensorType dstTy) {
-      auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
-      // Supported mma to dot conversion
-      if (nvidiaMma && nvidiaMma.isAmpere())
-        return true;
-      // No need to decompose if shared memory is not needed
-      return matchMmaV3AndDotOperandLayout(srcTy, dstTy) ||
-             cvtReordersRegisters(srcTy, dstTy);
-    };
+                                RankedTensorType dstTy) { return true; };
     ModuleOp mod = getOperation();
    triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod);
    triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod,

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp

Lines changed: 29 additions & 14 deletions
@@ -49,28 +49,43 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     Value warpId = udiv(tid, warpSize);
     Value laneId = urem(tid, warpSize);
 
+    auto kWidth =
+        cast<DotOperandEncodingAttr>(op.getType().getEncoding()).getKWidth();
+
     if (fpType == ScaleDotElemType::E2M1)
       xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
 
     // Each thread owns elements of 4 mxfp vectors so we need 4 scales
-    // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c +
-    // 16, c + 17
+    // Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2
+    // Then, we need elements c and c + 16 for the first two mxfp vectors
+    // and elements c + 1 and c + 17 for the last two mxfp vectors
     auto c = mul(udiv(laneId, i32_val(4)), i32_val(2));
-    std::array<Value, 4> ci = {c, add(c, i32_val(1)), add(c, i32_val(16)),
+    std::array<Value, 4> ci = {c, add(c, i32_val(16)), add(c, i32_val(1)),
                                add(c, i32_val(17))};
 
+    // TODO Move this logic to using LinearLayouts
+    // Each scale in a warp has to be replicated to cover a tile of shape mxk =
+    // 16x64 This 16x64 tile is split into 4 subtiles of shape 8x32, each of
+    // which will have to gather a scale and multiply its relevant part of the
+    // mxfp vector This tile of 8x32 is split in to 8x4 vectors, leaving each
+    // vector with 1x8 mxfp elements as long as kWidth * 4 <= 32
+    assert(kWidth <= 8 &&
+           "NYI for larger kWidth (but we could do it with less shuffles!)");
     for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
-      // column major as per the DotOperandEncoding(opidx=0) layout
-      auto si = std::array<Value, 4>{
-          targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[0]),
-          targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[2]),
-          targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[1]),
-          targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]),
-      };
-
-      for (int j = 0; j < 32; ++j) {
-        xVals[32 * i + j] =
-            LLVM::mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], si[j / 8]);
+      for (int mxfp = 0; mxfp < 2; ++mxfp) {
+        auto si = std::array<Value, 2>{
+            targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[mxfp * 2 + 0]),
+            targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[mxfp * 2 + 1])};
+        for (int rep = 0; rep < 8 / kWidth; ++rep) {
+          for (int subTile = 0; subTile < 2; ++subTile) {
+            for (int k = 0; k < kWidth; ++k) {
+              auto idx =
+                  32 * i + 16 * mxfp + rep * 2 * kWidth + subTile * kWidth + k;
+              xVals[idx] =
+                  LLVM::mxfpScaleBf16(rewriter, loc, xVals[idx], si[subTile]);
+            }
+          }
+        }
       }
     }
 
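To see which bf16 elements the rewritten loop touches with which shuffled scale, here is a standalone sketch that only reproduces the index arithmetic from the new loop nest. The concrete kWidth value is an assumption for illustration; in the actual pass it comes from the dot-operand encoding and is asserted to be at most 8.

// Hedged sketch of the element-to-scale mapping in the rewritten loop: each
// scale value i covers 32 bf16 elements, split into two mxfp vectors of 16,
// and within each vector into kWidth-sized subtiles that alternate between
// the two shuffled scales si[0] and si[1].
#include <cstdio>

int main() {
  const int kWidth = 4; // example value; the real code asserts kWidth <= 8
  const int i = 0;      // index of the scale value being applied
  for (int mxfp = 0; mxfp < 2; ++mxfp)
    for (int rep = 0; rep < 8 / kWidth; ++rep)
      for (int subTile = 0; subTile < 2; ++subTile)
        for (int k = 0; k < kWidth; ++k) {
          int idx = 32 * i + 16 * mxfp + rep * 2 * kWidth + subTile * kWidth + k;
          std::printf("xVals[%2d] scaled by si[%d] of mxfp pair %d\n",
                      idx, subTile, mxfp);
        }
  return 0;
}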
