Skip to content

Commit 3586668

Browse files
committed
Bank conflict
1 parent 6c5b353 commit 3586668

File tree

5 files changed

+66
-90
lines changed

5 files changed

+66
-90
lines changed

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,16 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
314314
let hasVerifier = 1;
315315
}
316316

317+
def TritonGEN_SIMDBlockMemoryAccessElementType : AnyTypeOf<[I8, I16, I32, I64]>;
318+
319+
def TritonGEN_SIMDBlockMemoryAccessType
320+
: AnyTypeOf<[TritonGEN_SIMDBlockMemoryAccessElementType,
321+
FixedVectorOfLengthAndType<[2, 4, 8], [TritonGEN_SIMDBlockMemoryAccessElementType]>,
322+
// Vectors of length 16 only allowed for i8 for now.
323+
FixedVectorOfLengthAndType<[16], [I8]>]>;
324+
317325
def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
318-
Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>,
326+
Results<(outs TritonGEN_SIMDBlockMemoryAccessType:$res)>,
319327
Arguments<(ins
320328
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr
321329
)> {
@@ -331,14 +339,12 @@ def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
331339
let assemblyFormat = [{
332340
operands ` ` attr-dict `:` functional-type(operands, results)
333341
}];
334-
335-
let hasVerifier = 1;
336342
}
337343

338344
def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
339345
Arguments<(ins
340346
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
341-
FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val
347+
TritonGEN_SIMDBlockMemoryAccessType:$val
342348
)> {
343349

344350
let summary = "simd block write";
@@ -353,7 +359,5 @@ def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
353359
let assemblyFormat = [{
354360
operands ` ` attr-dict `:` `(` type(operands) `)`
355361
}];
356-
357-
let hasVerifier = 1;
358362
}
359363
#endif // TRITONGEN_OPS

third_party/intel/lib/Analysis/Allocation.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
127127
unsigned warpsPerCTA = product(srcEncoding.getWarpsPerCTA());
128128
unsigned remaining = product(srcTy.getShape()) /
129129
(threadsPerWarp * threadsPerWarp * warpsPerCTA);
130-
SmallVector<unsigned> repShape{threadsPerWarp, threadsPerWarp, remaining,
131-
warpsPerCTA};
130+
SmallVector<unsigned> repShape{threadsPerWarp + 1, threadsPerWarp,
131+
remaining, warpsPerCTA};
132132
return ScratchConfig(repShape, repShape,
133133
/*inVec=*/1, /*outVec=*/threadsPerWarp);
134134
}

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,6 @@ template <typename Op> static LogicalResult verifyMatrixInput(Op op) {
4848
return success();
4949
}
5050

51-
static LogicalResult verifySIMDBlockTy(Operation *op, VectorType vecTy) {
52-
unsigned numElems = vecTy.getNumElements();
53-
IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());
54-
55-
// FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
56-
if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 &&
57-
(elemTy.getWidth() != 8 || numElems != 16))
58-
return op->emitOpError("unsupported vector type");
59-
60-
return success();
61-
}
62-
6351
//===----------------------------------------------------------------------===//
6452
// gen.sub_group_reduce
6553
//===----------------------------------------------------------------------===//
@@ -438,19 +426,3 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {
438426

439427
return success();
440428
}
441-
442-
//===----------------------------------------------------------------------===//
443-
// gen.simdblockread
444-
//===----------------------------------------------------------------------===//
445-
446-
LogicalResult TritonGEN::SIMDBlockReadOp::verify() {
447-
return verifySIMDBlockTy(*this, getRes().getType());
448-
}
449-
450-
//===----------------------------------------------------------------------===//
451-
// gen.simdblockwrite
452-
//===----------------------------------------------------------------------===//
453-
454-
LogicalResult TritonGEN::SIMDBlockWriteOp::verify() {
455-
return verifySIMDBlockTy(*this, getVal().getType());
456-
}

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
2929

3030
#include "llvm/ADT/StringRef.h"
31+
#include "llvm/ADT/TypeSwitch.h"
3132
#include "llvm/IR/Attributes.h"
3233
#include "llvm/Support/ErrorHandling.h"
3334
#include "llvm/Support/ModRef.h"
@@ -937,25 +938,34 @@ struct TritonMatrix2DBlockPrefetchLowering
937938
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
938939
OpType, TritonGEN::SIMDBlockReadOp,
939940
TritonGEN::SIMDBlockWriteOp>::value>>
940-
static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) {
941+
static std::string getSIMDBlockManglingName(OpType op, Type type) {
941942
constexpr bool isWrite =
942943
std::is_same<OpType, TritonGEN::SIMDBlockWriteOp>::value;
943944
const LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
944-
const unsigned numElems = vecTy.getNumElements();
945945
// Note: OCL builtin name here differs from regular mangling.
946946
std::string funcName = "intel_sub_group_block_";
947947
if constexpr (isWrite)
948948
funcName += "write";
949949
else
950950
funcName += "read";
951-
funcName += "_u" + intel::getTypeMangling(vecTy.getElementType()) +
952-
(numElems == 1 ? "" : std::to_string(numElems));
953-
funcName =
954-
"_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
955-
std::to_string(ptrTy.getAddressSpace()) +
956-
intel::getTypeMangling(vecTy.getElementType(), /*isUnsigned=*/true);
951+
TypeSwitch<Type>(type)
952+
.Case([&](VectorType vecType) {
953+
const unsigned numElems = vecType.getNumElements();
954+
funcName += "_u" + intel::getTypeMangling(vecType.getElementType()) +
955+
std::to_string(numElems);
956+
funcName = "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
957+
std::to_string(ptrTy.getAddressSpace()) +
958+
intel::getTypeMangling(vecType.getElementType(),
959+
/*isUnsigned=*/true);
960+
})
961+
.Case([&](IntegerType vecType) {
962+
funcName += "_u" + intel::getTypeMangling(type);
963+
funcName = "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
964+
std::to_string(ptrTy.getAddressSpace()) +
965+
intel::getTypeMangling(type, /*isUnsigned=*/true);
966+
});
957967
if constexpr (isWrite)
958-
funcName += intel::getTypeMangling(vecTy, /*isUnsigned=*/true);
968+
funcName += intel::getTypeMangling(type, /*isUnsigned=*/true);
959969
return funcName;
960970
}
961971

@@ -968,17 +978,17 @@ struct TritonSIMDBlockReadLowering
968978
matchAndRewrite(TritonGEN::SIMDBlockReadOp op, OpAdaptor adaptor,
969979
ConversionPatternRewriter &rewriter) const override {
970980
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
971-
VectorType vecTy = op.getRes().getType();
981+
Type type = op.getRes().getType();
972982

973-
std::string funcName = getSIMDBlockManglingName(op, vecTy);
983+
std::string funcName = getSIMDBlockManglingName(op, type);
974984
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
975985
/*other=*/LLVM::ModRefInfo::NoModRef,
976986
/*argMem=*/LLVM::ModRefInfo::Ref,
977987
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
978988
auto funcAttrs = noUnwindWillReturnAttrs;
979989
funcAttrs.memEffectsAttr = memAttr;
980990
LLVM::CallOp call = createDeviceFunctionCall(
981-
rewriter, funcName, vecTy, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});
991+
rewriter, funcName, type, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});
982992

983993
rewriter.replaceOp(op, call.getResult());
984994
return success();
@@ -995,9 +1005,9 @@ struct TritonSIMDBlockWriteLowering
9951005
ConversionPatternRewriter &rewriter) const override {
9961006
MLIRContext *ctx = rewriter.getContext();
9971007
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
998-
VectorType vecTy = op.getVal().getType();
1008+
Type type = op.getVal().getType();
9991009

1000-
std::string funcName = getSIMDBlockManglingName(op, vecTy);
1010+
std::string funcName = getSIMDBlockManglingName(op, type);
10011011

10021012
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
10031013
/*other=*/LLVM::ModRefInfo::NoModRef,
@@ -1006,7 +1016,7 @@ struct TritonSIMDBlockWriteLowering
10061016
auto funcAttrs = noUnwindWillReturnAttrs;
10071017
funcAttrs.memEffectsAttr = memAttr;
10081018
LLVM::CallOp call = createDeviceFunctionCall(
1009-
rewriter, funcName, void_ty(ctx), {ptrTy, vecTy},
1019+
rewriter, funcName, void_ty(ctx), {ptrTy, type},
10101020
{op.getPtr(), op.getVal()}, {}, funcAttrs);
10111021

10121022
rewriter.replaceOp(op, call);

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -767,14 +767,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
767767
rewriter.replaceOp(op, result);
768768
}
769769

770-
VectorType
771-
getTypeForSubGroupTranspose(ArrayRef<Value> inVals,
772-
ConversionPatternRewriter &rewriter) const {
773-
auto elementTy = cast<IntegerType>(inVals.front().getType());
774-
return elementTy.getWidth() <= 16 ? vec_ty(elementTy, 16)
775-
: vec_ty(elementTy, 8);
776-
}
777-
778770
Value wrapInVector(Location loc, VectorType type, ArrayRef<Value> values,
779771
ConversionPatternRewriter &rewriter) const {
780772
assert(type.getShape()[0] == values.size() && "Size mismatch");
@@ -800,18 +792,18 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
800792
performSubGroupTranspose(Location loc, ArrayRef<Value> inVals,
801793
ConversionPatternRewriter &rewriter,
802794
bool isContiguous) const {
803-
VectorType opType = getTypeForSubGroupTranspose(inVals, rewriter);
795+
Type elementType = inVals.front().getType();
804796
auto mod = rewriter.getInsertionPoint()->getParentOfType<ModuleOp>();
805-
unsigned vecWidth = opType.getShape()[0];
806797

807798
Value smemBase = LLVM::intel::getSharedMemoryBase(
808799
loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
809800
Type ptrType = smemBase.getType();
810-
811-
int numElements = inVals.size();
801+
int numRows = inVals.size();
812802
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
813-
int offset = threadsPerWarp;
803+
int rowLength = threadsPerWarp + 1;
814804
Type offsetType = getTypeConverter()->getIndexType();
805+
Value subGroupOffset =
806+
int_val(offsetType.getIntOrFloatBitWidth(), rowLength * numRows);
815807
Value subGroupId = getValueOrCreateCastToIndexLike(
816808
rewriter, loc, offsetType,
817809
rewriter.create<mlir::gpu::SubgroupIdOp>(
@@ -820,42 +812,40 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
820812
rewriter, loc, offsetType,
821813
rewriter.create<mlir::gpu::LaneIdOp>(loc,
822814
/*upper_bound=*/IntegerAttr{}));
823-
int wiStrideNum = isContiguous ? numElements : threadsPerWarp;
824-
Value wiStride =
825-
rewriter.create<LLVM::ConstantOp>(loc, offsetType, wiStrideNum);
826-
Value sgStride = rewriter.create<LLVM::ConstantOp>(
827-
loc, offsetType, threadsPerWarp * numElements);
828-
Value subGroupOffset = mul(sgStride, subGroupId);
829-
Type elementType = opType.getElementType();
830815
Value subGroupBasePtr = gep(ptrType, elementType, smemBase,
831816
ValueRange{subGroupOffset}, /*inbounds=*/true);
832817
Value base = subGroupBasePtr;
833-
// Store in matrix, transposed
834-
for (ArrayRef<Value> vals = inVals; !vals.empty();
835-
vals = vals.drop_front(vecWidth)) {
836-
ArrayRef<Value> curr = vals.take_front(vecWidth);
837-
Value vec = wrapInVector(loc, opType, curr, rewriter);
838-
rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, base, vec);
839-
base = gep(base.getType(), opType, base, ArrayRef<LLVM::GEPArg>{offset},
818+
for (Value val : inVals) {
819+
rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, base, val);
820+
base = gep(base.getType(), elementType, base,
821+
ArrayRef<LLVM::GEPArg>{rowLength},
840822
/*inbounds=*/true);
841823
}
842824

843-
// Load from matrix, non-trasposed.
844-
// As per SIMD block semantics, we have stored the elements in a matrix of
845-
// `Nxsub_group_size` size, so we need to load back in blocks of
846-
// `sub_group_size` (`N/sub_group_size` loads).
847-
Value workItemOffset = mul(wiStride, subGroupLocalId);
825+
int32_t numContiguous = isContiguous ? inVals.size() / threadsPerWarp : 1;
826+
int32_t workItemStride =
827+
isContiguous ? rowLength : rowLength * threadsPerWarp;
828+
Value workItemOffset =
829+
mul(subGroupLocalId, int_val(offsetType.getIntOrFloatBitWidth(),
830+
numContiguous * rowLength));
848831
Value workItemBasePtr =
849832
gep(ptrType, elementType, subGroupBasePtr, ValueRange{workItemOffset},
850833
/*inbounds=*/true);
851-
SmallVector<Value> transposedVecs;
852-
Type loadTy = vec_ty(opType.getElementType(), wiStrideNum);
853-
for (std::size_t i = 0, n = inVals.size(); i < n; i += wiStrideNum) {
854-
transposedVecs.push_back(load(loadTy, workItemBasePtr));
855-
workItemBasePtr = gep(ptrType, loadTy, workItemBasePtr,
856-
ArrayRef<LLVM::GEPArg>{offset}, /*inbounds=*/true);
834+
int32_t rowsPerThread = numRows / threadsPerWarp;
835+
SmallVector<Value> outputVals;
836+
for (int i = 0; i < rowsPerThread; ++i) {
837+
for (int j = 0; j < threadsPerWarp; ++j) {
838+
outputVals.push_back(load(elementType, workItemBasePtr));
839+
workItemBasePtr =
840+
gep(workItemBasePtr.getType(), elementType, workItemBasePtr,
841+
ArrayRef<LLVM::GEPArg>{1}, /*inbounds=*/true);
842+
}
843+
workItemBasePtr =
844+
gep(workItemBasePtr.getType(), elementType, workItemBasePtr,
845+
ArrayRef<LLVM::GEPArg>{workItemStride - threadsPerWarp},
846+
/*inbounds=*/true);
857847
}
858-
return unwrapFromVectors(loc, transposedVecs, rewriter);
848+
return outputVals;
859849
}
860850

861851
void performUnbroadcast(ConvertLayoutOp op, const LinearLayout &srcLayout,

0 commit comments

Comments
 (0)