
Commit b7eda0b

Merge commit '4bcdbdee14a992b6086afb1f3025fe2767fdbbec'
2 parents d26d7ea + 4bcdbde commit b7eda0b

16 files changed: +444, -115 lines

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [
 
   let arguments = (ins
     TTG_TensorOrMemDesc:$a,
-    TTG_TensorOrMemDesc:$b,
+    TTG_MemDescType:$b,
     TT_FpIntTensor:$c,
     Optional<I1>:$useC,
     DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
@@ -99,7 +99,7 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [
 
   let assemblyFormat = [{
     $a`,` $b`,` $c (`,` $useC^)? attr-dict
-    `:` type($a) `*` type($b) `->` type($d)
+    `:` type($a) `*` qualified(type($b)) `->` type($d)
   }];
 
   let extraClassDeclaration = [{

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 10 additions & 0 deletions
@@ -63,9 +63,17 @@ LogicalResult WarpGroupDotOp::verify() {
   auto nvmmaEnc = dyn_cast<NvidiaMmaEncodingAttr>(resTy.getEncoding());
   if (!nvmmaEnc || !nvmmaEnc.isHopper())
     return emitOpError("WGMMA result layout must be Hopper NVMMA");
+
+  if (!isa<NVMMASharedEncodingAttr, DotOperandEncodingAttr>(
+          getA().getType().getEncoding()))
+    return emitOpError("WGMMA A operand must have NVMMA shared or dot layout");
+  if (!isa<NVMMASharedEncodingAttr>(getB().getType().getEncoding()))
+    return emitOpError("WGMMA B operand must have NVMMA shared layout");
+
   auto numWarps = gpu::lookupNumWarps(getOperation());
   if (numWarps % 4)
     return emitOpError("WGMMA requires num_warps to be divisible by 4");
+
   auto retShapePerCTA = getShapePerCTA(resTy);
   int rank = retShapePerCTA.size();
   if (rank != 2)
@@ -74,12 +82,14 @@ LogicalResult WarpGroupDotOp::verify() {
     return emitOpError("WGMMA result M dimension must be divisible by 64");
   if (retShapePerCTA[1] % 8 != 0)
     return emitOpError("WGMMA result N dimension must be divisible by 8");
+
   auto aElemTy = getA().getType().getElementType();
   if (!(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
         aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
         aElemTy.isF32()))
     return emitOpError("WGMMA result element type must be F16, BF16, F32, "
                        "F8E5M2, F8E4M3FN, or integer type");
+
   if (getMaxNumImpreciseAcc() < 32 &&
       (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
       resTy.getElementType().isF32()) {
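
As a quick mental model of what the stricter verifier now enforces, here is a minimal Python sketch (illustrative only; the real checks are the C++ added above, and the helper name is made up):

def check_wgmma_constraints(m, n, num_warps, a_elem_ty):
    # Mirrors the divisibility and element-type rules from WarpGroupDotOp::verify().
    errors = []
    if num_warps % 4 != 0:
        errors.append("WGMMA requires num_warps to be divisible by 4")
    if m % 64 != 0:
        errors.append("WGMMA result M dimension must be divisible by 64")
    if n % 8 != 0:
        errors.append("WGMMA result N dimension must be divisible by 8")
    if a_elem_ty not in {"f16", "bf16", "f32", "f8e5m2", "f8e4m3fn", "i8"}:
        errors.append("unsupported A element type for WGMMA")
    return errors

assert check_wgmma_constraints(128, 256, 8, "f16") == []
assert check_wgmma_constraints(96, 256, 8, "f16")  # M = 96 is not divisible by 64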

lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp

Lines changed: 90 additions & 21 deletions
@@ -1,9 +1,13 @@
 #include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "triton/Analysis/Allocation.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Traits.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
 #include "llvm/ADT/EquivalenceClasses.h"
@@ -175,30 +179,92 @@ static TMemChunk allocFirstFit(MemoryBitMap &memoryMap,
   return chunk;
 }
 
-static Operation *getAlloc(Value value) {
-  while (true) {
-    if (auto allocOp = value.getDefiningOp<TMEMAllocOp>())
-      return allocOp;
-    if (auto indexOp = value.getDefiningOp<ttg::MemDescIndexOp>()) {
-      value = indexOp.getSrc();
+static SmallVector<Operation *> getAlloc(Value value) {
+  SmallVector<Operation *> allocs;
+  DenseSet<Value> seen;
+  SmallVector<Value> worklist{value};
+
+  while (!worklist.empty()) {
+    Value v = worklist.pop_back_val();
+    if (!seen.insert(v).second)
       continue;
-    }
-    if (auto reinterpOp = value.getDefiningOp<ttg::MemDescReinterpretOp>()) {
-      value = reinterpOp.getSrc();
+
+    // Handle block arguments.
+    if (auto arg = dyn_cast<BlockArgument>(v)) {
+      Block *block = arg.getOwner();
+      Operation *parentOp = block->getParentOp();
+
+      // Handle block with predecessors.
+      if (!block->isEntryBlock()) {
+        for (Block *pred : block->getPredecessors()) {
+          Operation *predOp = pred->getTerminator();
+          auto br = dyn_cast<BranchOpInterface>(predOp);
+          if (!br) {
+            llvm::report_fatal_error("unhandled branch op: " +
+                                     predOp->getName().getStringRef());
+          }
+          SmallVector<Attribute> operands(br->getNumOperands());
+          auto it = llvm::find(br->getSuccessors(), block);
+          unsigned idx = std::distance(br->getSuccessors().begin(), it);
+          SuccessorOperands args = br.getSuccessorOperands(idx);
+          Value operand =
+              args.getForwardedOperands()[arg.getArgNumber() -
+                                          args.getProducedOperandCount()];
+          worklist.push_back(operand);
+        }
+        continue;
+      }
+
+      // Handle region entry arguments.
+      if (auto wsOp = dyn_cast<ttg::WarpSpecializePartitionsOp>(parentOp)) {
+        worklist.push_back(
+            wsOp.getParentOp().getExplicitCaptures()[arg.getArgNumber()]);
+      } else if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
+        unsigned idx = arg.getArgNumber() - 1;
+        worklist.push_back(forOp.getYieldedValues()[idx]);
+        worklist.push_back(forOp.getInits()[idx]);
+      } else if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp)) {
+        unsigned idx = arg.getArgNumber();
+        if (arg.getParentRegion() == &whileOp.getAfter()) {
+          worklist.push_back(whileOp.getConditionOp().getArgs()[idx]);
+        } else {
+          worklist.push_back(whileOp.getYieldedValues()[idx]);
+          worklist.push_back(whileOp.getInits()[idx]);
+        }
+      } else {
+        llvm::report_fatal_error(
+            "unhandled parent op when looking for TMEM alloc: " +
+            parentOp->getName().getStringRef());
+      }
       continue;
     }
-    if (auto slice = value.getDefiningOp<TMEMSubSliceOp>()) {
-      value = slice.getSrc();
-      continue;
+
+    Operation *defOp = v.getDefiningOp();
+    unsigned idx = cast<OpResult>(v).getResultNumber();
+    if (isa<TMEMAllocOp>(defOp)) {
+      allocs.push_back(defOp);
+    } else if (defOp->hasTrait<OpTrait::MemDescViewTrait>()) {
+      worklist.push_back(defOp->getOperand(0));
+    } else if (auto sliceOp = dyn_cast<TMEMSubSliceOp>(defOp)) {
+      worklist.push_back(sliceOp.getSrc());
+    } else if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
+      worklist.push_back(selectOp.getTrueValue());
+      worklist.push_back(selectOp.getFalseValue());
+    } else if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
+      worklist.push_back(ifOp.thenYield().getOperand(idx));
+      worklist.push_back(ifOp.elseYield().getOperand(idx));
+    } else if (auto forOp = dyn_cast<scf::ForOp>(defOp)) {
+      worklist.push_back(forOp.getYieldedValues()[idx]);
+      worklist.push_back(forOp.getInits()[idx]);
+    } else if (auto whileOp = dyn_cast<scf::WhileOp>(defOp)) {
+      worklist.push_back(whileOp.getConditionOp().getArgs()[idx]);
+    } else {
+      llvm::report_fatal_error("unhandled op when looking for TMEM alloc: " +
+                               defOp->getName().getStringRef());
     }
-    auto arg = dyn_cast<BlockArgument>(value);
-    if (!arg || !isa<triton::gpu::WarpSpecializePartitionsOp>(
-                    arg.getOwner()->getParentOp()))
-      llvm::report_fatal_error("expected to find a TMEM alloc op");
-    auto partitions = cast<triton::gpu::WarpSpecializePartitionsOp>(
-        arg.getOwner()->getParentOp());
-    value = partitions.getParentOp().getExplicitCaptures()[arg.getArgNumber()];
   }
+
+  return allocs;
 }
 
 class RowIdConstraints {
@@ -245,8 +311,11 @@ allocateTMem(Operation *parentOp,
       if (allocSize.numRows == 64) {
         // HW restriction, the A alloc and accumulator needs to be in the same
         // rows.
-        rowIdConstraints.joinOps(getAlloc(mmaOp.getA()),
-                                 getAlloc(mmaOp.getAccumulator()));
+        SmallVector<Operation *> lhsAllocs = getAlloc(mmaOp.getA());
+        SmallVector<Operation *> accAllocs = getAlloc(mmaOp.getAccumulator());
+        for (Operation *lhsAlloc : lhsAllocs)
+          for (Operation *accAlloc : accAllocs)
+            rowIdConstraints.joinOps(lhsAlloc, accAlloc);
      } else {
         // TODO: we need to handle cases where the format is blockM and we
         // have multiple blocks.
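
The rewritten getAlloc above replaces a linear walk up a single def chain with a worklist traversal that can fan out at merge points (arith.select, scf.if/for/while results, block arguments), so one MMA operand may now resolve to several TMEM allocs, and allocateTMem joins every resulting pair. A minimal Python sketch of that pattern over a toy def graph (the dictionary encoding is illustrative, not a Triton API):

def collect_allocs(defs, root):
    # defs maps a value name to ("alloc",), ("view", src) or ("merge", src0, src1).
    allocs, seen, worklist = [], set(), [root]
    while worklist:
        v = worklist.pop()
        if v in seen:
            continue
        seen.add(v)
        kind, *srcs = defs[v]
        if kind == "alloc":
            allocs.append(v)
        else:
            # Views forward a single source; merges force us to follow every input.
            worklist.extend(srcs)
    return allocs

defs = {
    "a": ("alloc",),
    "b": ("alloc",),
    "v": ("view", "a"),        # e.g. a memdesc view or subslice
    "m": ("merge", "v", "b"),  # e.g. arith.select or an scf.if yield
}
assert sorted(collect_allocs(defs, "m")) == ["a", "b"]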

python/test/unit/language/test_core.py

Lines changed: 18 additions & 9 deletions
@@ -1969,24 +1969,33 @@ def serialized_add(data, Lock, SEM: tl.constexpr):
 
 
 @pytest.mark.interpreter
-@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed'])
+@pytest.mark.parametrize("sem", [None, "acquire", "release", "acq_rel", "relaxed"])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
-def test_tensor_atomic_cas(sem, num_ctas, device):
+@pytest.mark.parametrize("size", [4, 128, 512])
+@pytest.mark.parametrize("dtype_str", ['bfloat16', 'float16', 'float32', 'uint64', 'int64', 'float64'])
+def test_tensor_atomic_cas(sem, size, dtype_str, num_ctas, device):
+    check_type_supported(dtype_str, device)
+    if "float" in dtype_str and is_hip():
+        pytest.skip("HIP does not support atomic cas with float types")
 
     @triton.jit
-    def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
+    def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr, dtype: tl.constexpr):
         pid = tl.program_id(axis=0)
         block_start = pid * BLOCK_SIZE
         offsets = block_start + tl.arange(0, BLOCK_SIZE)
-        t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64)
-        t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64)
+        t1 = tl.full((BLOCK_SIZE, ), 0, dtype=dtype)
+        t2 = tl.full((BLOCK_SIZE, ), 2, dtype=dtype)
         tl.atomic_cas(X + offsets, t1, t2, sem=sem)
 
-    X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64)
-    Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64)
+    torch_dtype = getattr(torch, dtype_str)
+    X = torch.zeros((size, ), device=device, dtype=torch_dtype)
+    X[1::2] = 1
+    Y = X.clone()
+    Y[0::2] = 2
 
-    change_value[(2, )](X, 4, sem)
-    assert (torch.equal(X, Y))
+    tl_dtype = getattr(tl, dtype_str)
+    change_value[(2, )](X, BLOCK_SIZE=size // 2, sem=sem, dtype=tl_dtype)
+    assert torch.equal(X, Y)
 
 
 @pytest.mark.interpreter
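
For reference, the updated test relies on the elementwise compare-and-swap semantics of tl.atomic_cas: an element is overwritten only if it currently equals the expected value. A plain-PyTorch sketch of that behavior (illustrative only, not part of the test):

import torch

def reference_cas(x, expected, new):
    # Where x equals `expected`, store `new`; elsewhere leave x unchanged.
    return torch.where(x == expected, torch.full_like(x, new), x)

x = torch.tensor([0, 1, 0, 1], dtype=torch.int64)
assert torch.equal(reference_cas(x, 0, 2), torch.tensor([2, 1, 2, 1], dtype=torch.int64))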

python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_
 
     Args:
         a (tensor or shared_memory_descriptor): Left hand side operand.
-        b (tensor or shared_memory_descriptor): Right hand side operand.
+        b (shared_memory_descriptor): Right hand side operand.
         acc (tensor): Accumulator tensor.
         use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True.
         precision (str, optional): Dot input precision. Defaults to builder default.

python/triton/knobs.py

Lines changed: 0 additions & 1 deletion
@@ -518,7 +518,6 @@ class amd_knobs(base_knobs):
     use_buffer_atomics: env_bool = env_bool("AMDGCN_USE_BUFFER_ATOMICS", True)
     dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP")
     libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH")
-    lld_path: env_opt_str = env_opt_str("TRITON_HIP_LLD_PATH")
 
     # We use strs so that we can have a default value based on other runtime info
     use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")

python/triton/runtime/interpreter.py

Lines changed: 8 additions & 0 deletions
@@ -88,10 +88,18 @@ def validate(self):
         assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
         assert len(self.strides) == self.ndim
         assert len(self.block_shape) == self.ndim
+        assert self.ndim >= 1, "descriptor cannot be 0 dimensional"
 
         for stride in self.strides[:-1]:
             assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
         assert self.strides[-1].data.item() == 1, "last dim must be contiguous"
+        for i in range(self.ndim - 1):
+            stride = self.strides[i].data.item()
+            prev_stride = self.strides[i + 1].data.item()
+            prev_size = self.shape[i + 1].data.item()
+            assert stride >= prev_stride, "strides must be ordered largest to smallest"
+            assert (stride % prev_stride) == 0, "strides must be even multiples of smaller strides"
+            assert (stride // prev_stride) >= prev_size, "invalid stride"
 
     def materialize_pointers(self, offsets: List[TensorHandle]):
         assert len(offsets) == self.ndim
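
The new loop checks that descriptor strides are sorted from largest to smallest, that each stride is an exact multiple of the next one, and that the ratio covers the size of the next dimension, which is exactly what contiguous row-major strides satisfy. A standalone sketch of the same invariants on plain integers (illustrative; independent of the interpreter's TensorHandle wrappers):

def row_major_strides(shape):
    # e.g. (4, 8, 16) -> [128, 16, 1]
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def check_strides(shape, strides):
    assert strides[-1] == 1, "last dim must be contiguous"
    for i in range(len(shape) - 1):
        assert strides[i] >= strides[i + 1], "strides must be ordered largest to smallest"
        assert strides[i] % strides[i + 1] == 0, "strides must be even multiples of smaller strides"
        assert strides[i] // strides[i + 1] >= shape[i + 1], "invalid stride"

shape = (4, 8, 16)
check_strides(shape, row_major_strides(shape))  # passes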

test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir

Lines changed: 71 additions & 0 deletions
@@ -223,3 +223,74 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [1, 0]}>
+// CHECK-LABEL: mfma_dot_scaled_mxfp4_b_packed_mn
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_mxfp4_b_packed_mn(
+      %a: tensor<128x128xf8E5M2, #blocked>,
+      %b: tensor<128x64xi8, #blocked1>,
+      %c: tensor<128x128xf32, #blocked>,
+      %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>
+  ) {
+    %b1 = ttg.convert_layout %b : tensor<128x64xi8, #blocked1> -> tensor<128x64xi8, #blocked>
+    // CHECK: %[[ALLOCB:.+]] = ttg.local_alloc {{.*}} : (tensor<128x64xi8, #blocked>) -> !ttg.memdesc<128x64xi8, #shared, #smem>
+    // CHECK: %[[B:.+]] = amdgpu.local_load_packed_tranposed %[[ALLOCB]] : !ttg.memdesc<128x64xi8, #shared, #smem> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+    // CHECK: tt.dot_scaled %{{.*}}, %[[B]], %{{.*}} lhs = e5m2 rhs = e2m1 {fastMath = false}
+    %accumulator_52 = tt.dot_scaled %a, %b1, %c lhs = e5m2 rhs = e2m1 {fastMath = false, rhs_k_pack = false} : tensor<128x128xf8E5M2, #blocked> * tensor<128x64xi8, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.store %arg4, %accumulator_52 : tensor<128x128x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [0, 1]}>
+// CHECK-LABEL: mfma_dot_scaled_mxfp4_a_packed_mn
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_mxfp4_a_packed_mn(
+      %a: tensor<64x128xi8, #blocked>,
+      %b: tensor<128x128xf8E5M2, #blocked1>,
+      %c: tensor<128x128xf32, #blocked>,
+      %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>
+  ) {
+    %b1 = ttg.convert_layout %b : tensor<128x128xf8E5M2, #blocked1> -> tensor<128x128xf8E5M2, #blocked>
+    // CHECK: %[[ALLOCA:.+]] = ttg.local_alloc {{.*}} : (tensor<64x128xi8, #blocked>) -> !ttg.memdesc<64x128xi8, #shared, #smem>
+    // CHECK: %[[A:.+]] = amdgpu.local_load_packed_tranposed %[[ALLOCA]] : !ttg.memdesc<64x128xi8, #shared, #smem> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+    // CHECK: tt.dot_scaled %[[A]], %{{.*}}, %{{.*}} lhs = e2m1 rhs = e5m2 {fastMath = false}
+    %accumulator_52 = tt.dot_scaled %a, %b1, %c lhs = e2m1 rhs = e5m2 {fastMath = false, lhs_k_pack = false} : tensor<64x128xi8, #blocked> * tensor<128x128xf8E5M2, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.store %arg4, %accumulator_52 : tensor<128x128x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [0, 1]}>
+// CHECK{LITERAL}: #shared1 = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [1, 0]}>
+// CHECK-LABEL: mfma_dot_scaled_mxfp4_ab_packed_mn
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_mxfp4_ab_packed_mn(
+      %a: tensor<64x128xi8, #blocked>,
+      %b: tensor<128x64xi8, #blocked1>,
+      %c: tensor<128x128xf32, #blocked>,
+      %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>
+  ) {
+    %b1 = ttg.convert_layout %b : tensor<128x64xi8, #blocked1> -> tensor<128x64xi8, #blocked>
+    // CHECK: %[[ALLOCA:.+]] = ttg.local_alloc {{.*}} : (tensor<64x128xi8, #blocked>) -> !ttg.memdesc<64x128xi8, #shared, #smem>
+    // CHECK: %[[A:.+]] = amdgpu.local_load_packed_tranposed %[[ALLOCA]] : !ttg.memdesc<64x128xi8, #shared, #smem> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+    // CHECK: %[[ALLOCB:.+]] = ttg.local_alloc {{.*}} : (tensor<128x64xi8, #blocked>) -> !ttg.memdesc<128x64xi8, #shared1, #smem>
+    // CHECK: %[[B:.+]] = amdgpu.local_load_packed_tranposed %[[ALLOCB]] : !ttg.memdesc<128x64xi8, #shared1, #smem> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+    // CHECK: tt.dot_scaled %[[A]], %[[B]], %{{.*}} lhs = e2m1 rhs = e2m1 {fastMath = false}
+    %accumulator_52 = tt.dot_scaled %a, %b1, %c lhs = e2m1 rhs = e2m1 {fastMath = false, lhs_k_pack = false, rhs_k_pack = false} : tensor<64x128xi8, #blocked> * tensor<128x64xi8, #blocked> -> tensor<128x128xf32, #blocked>
+    tt.store %arg4, %accumulator_52 : tensor<128x128x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
