Commit 2dbc39c

Merge commit '6c3e9535c44774dfd56357acba9c2183b247f58e'
2 parents: 9c52bc3 + 6c3e953

File tree: 44 files changed (+1486 / -406 lines)


.github/workflows/integration-tests.yml

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ env:
   TRITON_BUILD_WITH_CLANG_LLD: "TRUE"
   TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE"
   TRITON_DISABLE_LINE_INFO: 1
+  PROTON_SKIP_PC_SAMPLING_TEST: 1
 jobs:
   Runner-Preparation:
     runs-on: ubuntu-latest

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ env:
   TRITON_BUILD_WITH_CLANG_LLD: "TRUE"
   TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE"
   TRITON_DISABLE_LINE_INFO: 1
-
+  PROTON_SKIP_PC_SAMPLING_TEST: 1

 jobs:
   Runner-Preparation:

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 9 deletions
@@ -57,15 +57,6 @@ class TargetInfoBase {
                              unsigned numLaneToReduce,
                              unsigned interleave) const = 0;

-  // TODO (Keren): Remove this function once layout conversion using stmatrix is
-  // handled by Linear Layout.
-  virtual bool processReplicaUsingStMatrix(
-      RewriterBase &rewriter, Location loc, Value smemBase,
-      SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
-      ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
-      ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
-      int swizzleByteWidth = 0) const = 0;
-
   virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
   // Emits LLVM code with |rewriter| to print a message following the given
   // format from the device. |formatStrStart| is the pointer to the start of

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 125 additions & 3 deletions
@@ -113,12 +113,134 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 // row0 reg[0-1] reg[4-5]
 // row8 reg[2-3] reg[6-7]
 //
+// When `swizzleByteSize` is non-zero, the layout is constructed
+// differently due to the leading dimension offset and swizzling.
+// There are two key concepts to understand:
+//
+// 1. Chunks: The leading dimension (i.e., the column dimension) is divided
+//    into chunks, where each chunk's size is determined by `swizzleByteSize`.
+// 2. Swizzling within tiles: Each tile applies a swizzling pattern to its
+//    rows to optimize memory access.
+//
+// - Concept 1: Chunks
+//
+// In the swizzled layout, the leading dimension is strided by
+// `swizzleByteSize`. This introduces the concept of a "chunk", where each
+// chunk spans a certain number of columns.
+//
+// For a tile size of `stmatrix.x4` (16x16 elements), with each element being
+// 16 bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since 16
+// elements * 2 bytes per element = 32 bytes per row).
+//
+// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be
+// calculated as:
+//
+//   Number of tiles per chunk = swizzleByteSize / (bytes per row)
+//                             = 128 bytes / 32 bytes = 4 tiles
+//
+// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
+// (since each tile is 16 columns wide):
+//
+//           col0-15  col16-31  col32-47  col48-63
+//   row0-15 tile0    tile1     tile2     tile3
+//
+// For a tensor of size 128x128 elements (#rows x #columns), with each element
+// being 16 bits, the tensor can be divided into multiple chunks both
+// horizontally and vertically. Chunks are stored in memory in chunk-level
+// "column-major" order, meaning chunk1's address follows chunk0's.
+//
+// Assuming we have 8 warps, we assign each warp to process a chunk of 16
+// rows (the rows of one tile) and 128 columns (the width of two chunks), so
+// each warp handles one horizontal slice of the tensor.
+//
+// The overall layout can be visualized as:
+//
+//                       |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
+//                           columns 0-63          columns 64-127
+//   warp0 | rows 0-15       chunk0                chunk8
+//   warp1 | rows 16-31      chunk1                chunk9
+//   warp2 | rows 32-47      chunk2                chunk10
+//   warp3 | rows 48-63      chunk3                chunk11
+//   warp4 | rows 64-79      chunk4                chunk12
+//   warp5 | rows 80-95      chunk5                chunk13
+//   warp6 | rows 96-111     chunk6                chunk14
+//   warp7 | rows 112-127    chunk7                chunk15
+//
+// - Concept 2: Swizzling within tiles
+//
+// Within each 16x16 tile, rows are swizzled to optimize memory access
+// patterns. This swizzling is similar to what's defined in
+// `TritonGPUAttrDefs.td`, applied at the level of each 16x16 tile rather
+// than the entire tensor.
+//
+// Key parameters for swizzling:
+//
+// - `perPhase`: The number of rows over which to apply an XOR operation at
+//   each phase.
+// - `maxPhase`: The total number of phases.
+// - `vectorWidth`: The number of elements per vector, which is 8 in this case
+//   because `stmatrix` stores 8 contiguous elements per thread.
+//
+// The offset of each element within a tile is calculated using the formula:
+//
+//   offset = row * swizzleByteSize +
+//            (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
+//
+// where `elementSize` is the size of each element in bytes (2 bytes for 16-bit
+// elements).
+//
+// For example, consider the element at index `(row=1, col=0)` in chunk0:
+//
+// Without swizzling:
+//
+//   offset = row * swizzleByteSize + col * elementSize
+//          = 1 * 128 bytes + 0 * 2 bytes
+//          = 128 bytes
+//
+// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`):
+//
+//   offset = row * swizzleByteSize +
+//            (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
+//          = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes
+//          = 128 bytes + (8 * 1) * 2 bytes
+//          = 128 bytes + 16 bytes
+//          = 144 bytes
+//
+// This swizzling ensures that elements are stored in a way that optimizes for
+// memory bandwidth and reduces bank conflicts.
+//
+// - Verification through Linear Layout
+//
+// We can verify the offsets with the following outputs of the corresponding
+// linear layout, where each element is 16 bits (2 bytes):
+//
+//   - register=1  -> offset=1
+//     register=2  -> offset=2
+//     register=4  -> offset=4
+//     register=8  -> offset=16
+//     register=16 -> offset=32
+//     register=32 -> offset=8192
+//   - lane=1  -> offset=72
+//     lane=2  -> offset=144
+//     lane=4  -> offset=288
+//     lane=8  -> offset=512
+//     lane=16 -> offset=8
+//   - warp=1 -> offset=1024
+//     warp=2 -> offset=2048
+//     warp=4 -> offset=4096
+//
+// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in
+// `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result
+// matches our earlier calculation.
+//
 // TODO(Keren): We should replace tensorTy with a LinearLayout and the element
 // bit width of the tensor in the future to support more flexible tensor
 // encodings
-std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
-    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order);
+std::optional<LinearLayout>
+chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                     ArrayRef<unsigned> repShape,
+                     ArrayRef<unsigned> paddedRepShape,
+                     ArrayRef<unsigned> order, int swizzleByteSize);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
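To make the worked example above easy to check mechanically, here is a small Python sketch (not part of the commit; the function name is ours) that replays the comment's offset formula with its stated parameters:

```python
# Replays the swizzled-offset formula from the comment; parameter values and
# expected results come from the comment itself.
def swizzled_offset(row, swizzle_byte_size=128, per_phase=1, max_phase=8,
                    vector_width=8, element_size=2):
    """Byte offset of element (row, col=0) within a chunk."""
    return (row * swizzle_byte_size +
            (vector_width * ((row // per_phase) % max_phase)) * element_size)

# (row=1, col=0): 128 bytes of row stride + 16 bytes of swizzle = 144 bytes,
# matching the linear-layout output lane=1 -> offset=72 (72 * 2 bytes = 144).
assert swizzled_offset(1) == 144
```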

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 6 additions & 12 deletions
@@ -215,15 +215,9 @@ struct ConvertLayoutOpConversion
       if (repId != 0) {
         barrier();
       }
-      auto successful = targetInfo.processReplicaUsingStMatrix(
-          rewriter, loc, smemBase, vals, srcTy,
-          getTypeConverter()->convertType(srcTy.getElementType()),
-          paddedRepShape, origRepShape, outOrd, accumNumReplicates);
-      if (!successful) {
-        processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
-                       multiDimRepId, inVec, paddedRepShape, origRepShape,
-                       outOrd, vals, smemBase);
-      }
+      processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
+                     multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd,
+                     vals, smemBase);
       barrier();
       processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
                      multiDimRepId, outVec, paddedRepShape, origRepShape,
@@ -483,9 +477,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // Input dims: [reg, lane, warp]
     // Output dims: [offset, iteration]
     std::optional<LinearLayout> shmemStoreLayout =
-        chooseStMatrixLayoutForRegToRegConversion(
-            ctx, op.getSrc().getType(), scratchConfig.repShape,
-            scratchConfig.paddedRepShape, scratchConfig.order);
+        chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
+                             scratchConfig.paddedRepShape, scratchConfig.order,
+                             /*swizzleByteSize=*/0);
     bool isStMatrix = shmemStoreLayout.has_value();
     if (!isStMatrix) {
       shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
@@ -116,7 +116,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    // TODO: do we need to check if src is shared ?
     if (isa<SharedEncodingAttr>(srcLayout) &&
         isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
             dstLayout)) {

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 96 additions & 7 deletions
@@ -820,8 +820,8 @@ namespace {
 // stmatrix. These restrictions are retained from legacy code, and we could
 // relax some of them in the future.
 bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-                    ArrayRef<unsigned> paddedRepShape,
-                    ArrayRef<unsigned> order) {
+                    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
+                    int swizzleByteSize) {
   auto mmaLayout =
       mlir::dyn_cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
   if (!mmaLayout || !mmaLayout.isHopper())
@@ -840,17 +840,87 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     return false;
   if (paddedRepShape[1] % 8 != 0)
     return false;
+  if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 &&
+      swizzleByteSize != 128)
+    return false;
   return true;
 }

-} // anonymous namespace
+std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
+    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
+    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
+    int swizzleByteSize) {
+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+  StringAttr kCol = S("dim1");
+  StringAttr kRow = S("dim0");
+  StringAttr kOffset = S("offset");
+
+  int perPhase;
+  int maxPhase;
+  if (swizzleByteSize == 32) {
+    perPhase = 4;
+    maxPhase = 2;
+  } else if (swizzleByteSize == 64) {
+    perPhase = 2;
+    maxPhase = 4;
+  } else if (swizzleByteSize == 128) {
+    perPhase = 1;
+    maxPhase = 8;
+  } else {
+    llvm::errs() << "Illegal swizzleByteSize: " << swizzleByteSize << "\n";
+    llvm::report_fatal_error("Illegal swizzleByteSize");
+  }
+
+  // stmatrix only supports 16-bit elements, and each vector has 8 elements
+  int elemBitWidth = 16;
+  int vecSize = 8;
+  int numRows = 16;
+  int numCols = 8 * swizzleByteSize / elemBitWidth;
+
+  // Construct a single stmatrix.x4 (16x16) tile
+  std::vector<std::vector<int>> basesReg = {{1, 0}, {2, 0}, {4, 0}};
+  std::vector<std::vector<int>> basesLane;
+  for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) {
+    int row = 1 << logRow;
+    basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row});
+  }
+  basesLane.push_back({8, 0});
+
+  // Expand the tile's register dimension to fit swizzleByteSize, which is a
+  // "chunk"
+  for (int logChunk = 0; logChunk < llvm::Log2_32(numCols / 16); logChunk++) {
+    int chunk = 1 << logChunk;
+    basesReg.push_back({16 * chunk, 0});
+  }
+
+  // Construct the layout for a single chunk
+  LinearLayout layout =
+      LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});

-std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
+  // Expand the `warp` dimension according to warpsPerCTA.
+  auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
+  layout *=
+      identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
+          .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+
+  // Expand the `register` dimension so the size of columns matches `n`.
+  int n = mma.getInstrShape()[1];
+  int numWarpRows = layout.getOutDimSize(kRow);
+  layout = (layout.reshapeOuts({{kOffset, layout.getTotalOutDimSize()}}) *
+            LinearLayout::identity1D(n / numCols, kReg, kOffset))
+               .reshapeOuts({{kCol, n}, {kRow, numWarpRows}});
+
+  auto ret =
+      combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
+  return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
+      .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+}
+
+std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
     MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
-  if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order))
-    return std::nullopt;
-
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
   StringAttr kWarp = S("warp");
@@ -880,4 +950,23 @@ std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
       {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }

+} // anonymous namespace
+
+std::optional<LinearLayout>
+chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                     ArrayRef<unsigned> repShape,
+                     ArrayRef<unsigned> paddedRepShape,
+                     ArrayRef<unsigned> order, int swizzleByteSize) {
+  if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order,
+                      swizzleByteSize))
+    return std::nullopt;
+
+  if (swizzleByteSize == 0)
+    return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
+                                               paddedRepShape, order);
+  else
+    return chooseStMatrixLayoutLeadingOffset(
+        ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize);
+}
+
 } // namespace mlir::triton::gpu
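As a cross-check of the swizzle parameters above, the following standalone Python sketch (ours, not from the commit) mirrors how `chooseStMatrixLayoutLeadingOffset` derives its lane bases for each `swizzleByteSize`:

```python
from math import log2

# perPhase/maxPhase pairs as selected in the C++ above; the rest mirrors the
# basesLane loop. The function name and structure are ours.
def lane_bases(swizzle_byte_size):
    per_phase, max_phase = {32: (4, 2), 64: (2, 4), 128: (1, 8)}[swizzle_byte_size]
    vec_size, num_rows = 8, 16
    bases = []
    for log_row in range(int(log2(num_rows))):
        row = 1 << log_row
        # Each row's column offset is shifted by its swizzle phase.
        bases.append([vec_size * ((row // per_phase) % max_phase), row])
    bases.append([8, 0])  # the second 8-column half of the 16x16 tile
    return bases

print(lane_bases(128))  # [[8, 1], [16, 2], [32, 4], [0, 8], [8, 0]]
```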

python/setup.py

Lines changed: 9 additions & 1 deletion
@@ -638,6 +638,14 @@ def get_entry_points():
     return entry_points


+def get_git_commit_hash(length=8):
+    try:
+        cmd = ['git', 'rev-parse', f'--short={length}', 'HEAD']
+        return "+git{}".format(subprocess.check_output(cmd).strip().decode('utf-8'))
+    except Exception:
+        return ""
+
+
 def get_install_requires():
     install_requires = [
         "packaging",  # used by third_party/intel/backend/compiler.py
@@ -647,7 +655,7 @@ def get_install_requires():

 setup(
     name=os.environ.get("TRITON_WHEEL_NAME", "triton"),
-    version="3.0.0" + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""),
+    version="3.0.0" + get_git_commit_hash() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""),
     author="Philippe Tillet",
     author_email="phil@openai.com",
     description="A language and compiler for custom Deep Learning operations",
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline | FileCheck %s
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#loc = loc("/data/users/dberard/triton-env/scripts/matmul.py":6:0)
+#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}>
+module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: tt.func @use_dep_args
+  tt.func @use_dep_args(%a_ptrs: tensor<64x32x!tt.ptr<bf16>, #blocked>, %b_ptrs: tensor<32x64x!tt.ptr<bf16>, #blocked1>, %loop_range: i32) -> (tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<32x64x!tt.ptr<bf16>, #blocked1>) {
+    %cst = arith.constant dense<32> : tensor<64x32xi32, #blocked>
+    %cst2 = arith.constant dense<2048> : tensor<32x64xi32, #blocked1>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+    %c0_i32 = arith.constant 0 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %c32_i32 = arith.constant 32 : i32
+    // CHECK: tt.load
+    // CHECK: [[FOR_OUT:%[a-z0-9_]+]]:{{[0-9]+}} = scf.for
+    %for:3 = scf.for %arg6 = %c0_i32 to %loop_range step %c32_i32 iter_args(%arg7 = %cst_0, %arg8 = %a_ptrs, %arg9 = %b_ptrs) -> (tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<32x64x!tt.ptr<bf16>, #blocked1>) : i32 {
+      %63 = tt.load %arg8 : tensor<64x32x!tt.ptr<bf16>, #blocked>
+      %64 = tt.load %arg9 : tensor<32x64x!tt.ptr<bf16>, #blocked1>
+      %65 = triton_gpu.convert_layout %63 : tensor<64x32xbf16, #blocked> -> tensor<64x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %66 = triton_gpu.convert_layout %64 : tensor<32x64xbf16, #blocked1> -> tensor<32x64xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %67 = tt.dot %65, %66, %arg7 : tensor<64x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma>
+      %68 = tt.addptr %arg8, %cst : tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<64x32xi32, #blocked>
+      %69 = tt.addptr %arg9, %cst2 : tensor<32x64x!tt.ptr<bf16>, #blocked1>, tensor<32x64xi32, #blocked1>
+      scf.yield %67, %68, %69 : tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<32x64x!tt.ptr<bf16>, #blocked1>
+    }
+    // CHECK: tt.return {{[^,]+}}, [[FOR_OUT]]#3, [[FOR_OUT]]#4
+    tt.return %for#0, %for#1, %for#2 : tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<bf16>, #blocked>, tensor<32x64x!tt.ptr<bf16>, #blocked1>
+  }
+}
