Skip to content

Commit eeed3f5

Browse files
Fix build and test failures from 593a1b5
Signed-off-by: Whitney Tsang <[email protected]>
1 parent: 861f217; commit: eeed3f5

File tree

6 files changed

+13
-22
lines changed

6 files changed

+13
-22
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ The semantic of this `tt.dot` includes GEMM tiling configuration as:
184184
"unsigned":$systolicDepth,
185185
"unsigned":$executionSize,
186186
"unsigned":$opsPerChannel,
187-
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
187+
ArrayRefParameter<"unsigned">:$warpsPerCTA,
188188
ArrayRefParameter<"unsigned">:$repCluster,
189189
"unsigned":$threadsPerWarp
190190
);

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 3 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -169,7 +169,7 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape,
169169
OpIdx opIdx) const {
170170
// Always return a 3D shape repetitions for the ease of value handling, same
171171
// to mma.
172-
SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
172+
auto warpsPerCTA = getWarpsPerCTA();
173173
size_t rank = shape.size();
174174
SmallVector<int64_t> rep(3, 1);
175175
switch (opIdx) {
@@ -239,11 +239,6 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
239239
llvm_unreachable("unexpected opIdx");
240240
}
241241

242-
SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
243-
return SmallVector<unsigned>(getWarpsPerCTA__().begin(),
244-
getWarpsPerCTA__().end());
245-
}
246-
247242
SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
248243
size_t rank = getWarpsPerCTA().size();
249244
assert(rank == 2 || rank == 3);
@@ -295,7 +290,7 @@ unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
295290
LogicalResult DpasEncodingAttr::verify(
296291
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
297292
unsigned repeatCount, unsigned systolicDepth, unsigned executionSize,
298-
unsigned opsPerChan, ::llvm::ArrayRef<unsigned> warpsPerCTA__,
293+
unsigned opsPerChan, ::llvm::ArrayRef<unsigned> warpsPerCTA,
299294
::llvm::ArrayRef<unsigned> repCluster, unsigned sugGroupSize) {
300295
if (repeatCount > 8 || repeatCount < 1) {
301296
return emitError() << "repeatCount must be in the range [1, 8], but was:"
@@ -378,7 +373,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
378373
ArrayRef<unsigned> rB = shapeB;
379374
SmallVector<unsigned> shapeC = getShapeC();
380375
ArrayRef<unsigned> rC = shapeC;
381-
SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
376+
auto warpsPerCTA = getWarpsPerCTA();
382377
ArrayRef<unsigned> repCluster = getRepCluster();
383378
printer << "<{"
384379
<< "repeatCount = " << getRepeatCount() << ", "
@@ -436,10 +431,6 @@ SmallVector<unsigned> WarpEncodingAttr::getRepOrder() const {
436431
llvm::report_fatal_error("NYI. WarpEncodingAttr::getRepOrder");
437432
}
438433

439-
SmallVector<unsigned> WarpEncodingAttr::getWarpsPerCTA() const {
440-
llvm::report_fatal_error("NYI. WarpEncodingAttr::getWarpsPerCTA");
441-
}
442-
443434
LinearLayout WarpEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
444435
llvm::report_fatal_error("NYI. WarpEncodingAttr::toLinearLayout");
445436
}

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -418,7 +418,7 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
418418
StringAttr kLane = S("lane");
419419
StringAttr kWarp = S("warp");
420420

421-
const SmallVector<unsigned> warpsPerCTA = dpas.getWarpsPerCTA();
421+
auto warpsPerCTA = dpas.getWarpsPerCTA();
422422
int threadsPerWarp = product<unsigned>(dpas.getThreadsPerWarp());
423423
unsigned opsPerChannel = dpas.getOpsPerChannel();
424424
auto repCluster = dpas.getRepCluster();

third_party/intel/lib/TritonIntelGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -85,9 +85,8 @@ static SmallVector<Value> computeCrossWarpHistogram(
8585
Value threadId, int numWarps) {
8686
auto b = TritonLLVMOpBuilder(loc, rewriter);
8787
SmallVector<Value> histogramValues;
88-
unsigned numWarpsWithUniqueData =
89-
mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(),
90-
srcType.getShape())[0];
88+
unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA(
89+
srcType.getEncoding(), srcType.getShape())[0];
9190
Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1));
9291
// Initialize the shared memory with zeros.
9392
int64_t numElementPerThread =

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -555,7 +555,7 @@ struct LoadOpToBlockIOConversion
555555
unsigned numElems = getTotalElemsPerThread(resultType);
556556
SmallVector<int64_t> numReps =
557557
dpasLayout.getDPASRepetitions(tensorShape, opIdx);
558-
const SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
558+
auto warpsPerCTA = dpasLayout.getWarpsPerCTA();
559559
SmallVector<unsigned> dpasWarpsOrder =
560560
getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true);
561561
unsigned threadsPerWarp =
@@ -1046,7 +1046,7 @@ struct LoadOpConversion
10461046
unsigned numElems = getTotalElemsPerThread(resultType);
10471047
SmallVector<int64_t> numReps =
10481048
dpasLayout.getDPASRepetitions(tensorShape, opIdx);
1049-
const SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
1049+
auto warpsPerCTA = dpasLayout.getWarpsPerCTA();
10501050
SmallVector<unsigned> dpasWarpsOrder =
10511051
getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true);
10521052
unsigned threadsPerWarp =
@@ -1678,7 +1678,7 @@ struct StoreOpConversion
16781678
size_t rank = tensorShape.size();
16791679
unsigned numElems = getTotalElemsPerThread(tensorType);
16801680
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
1681-
const SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
1681+
auto warpsPerCTA = dpasLayout.getWarpsPerCTA();
16821682
SmallVector<int64_t> numReps =
16831683
dpasLayout.getDPASRepetitions(tensorShape, 2);
16841684
SmallVector<unsigned> dpasWarpsOrder =

third_party/intel/lib/TritonIntelGPUTransforms/DistributeToWarps.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ SmallVector<Value> distributeOffset(const SmallVector<Value> &oldOffsets,
120120
Attribute layout = tensorType.getEncoding();
121121
if (auto dotEncoding = dyn_cast<ttg::DotOperandEncodingAttr>(layout))
122122
layout = dotEncoding.getParent();
123-
const SmallVector<unsigned> &warpsPerCTA = ttg::getWarpsPerCTA(layout);
123+
auto warpsPerCTA = cast<ttg::BlockedEncodingAttr>(layout).getWarpsPerCTA();
124124
size_t dims = warpsPerCTA.size();
125125
assert(dims <= 2 && "no more than 2D shape");
126126

@@ -211,7 +211,8 @@ void distributeMakeRangeOp(tt::MakeRangeOp op, Value warpId) {
211211
auto sliceLayout = dyn_cast<ttg::SliceEncodingAttr>(tensorTy.getEncoding());
212212
assert(sliceLayout && "Expected slice layout");
213213

214-
auto parentWarpsPerCTA = ttg::getWarpsPerCTA(sliceLayout.getParent());
214+
auto parentWarpsPerCTA =
215+
cast<ttg::BlockedEncodingAttr>(sliceLayout.getParent()).getWarpsPerCTA();
215216
assert(parentWarpsPerCTA.size() == 2 && "Only slice of 2D layout supported");
216217
assert(parentWarpsPerCTA.back() == 1 &&
217218
"Warp distribution on second dimensions unsupported");

0 commit comments

Comments
 (0)