Skip to content

Commit 6536edb

Browse files
authored
[NFI]: Refactor LoadStoreOpToLLVM.cpp (#2623)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent c0987c7 commit 6536edb

File tree

1 file changed

+31
-41
lines changed

1 file changed

+31
-41
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
1313
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
14-
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
1514

1615
using namespace mlir;
1716
using namespace mlir::triton;
@@ -194,7 +193,7 @@ struct LoadStoreConversionBase {
194193
ArrayRef<int32_t> boundaryCheck = {},
195194
std::optional<PaddingOption> padding = std::nullopt) const {
196195

197-
auto rank = tensorType.getRank();
196+
size_t rank = tensorType.getRank();
198197
// The block pointer struct is expected to have the following layout:
199198
// Struct {
200199
// Value offset[rank];
@@ -818,36 +817,31 @@ struct LoadOpConversion
818817
LogicalResult
819818
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
820819
ConversionPatternRewriter &rewriter) const override {
821-
auto loc = op->getLoc();
820+
if (isTensorPointerType(op.getPtr().getType()))
821+
if (rewriteTensorPointerLoad(op, adaptor, rewriter).succeeded())
822+
return success();
823+
824+
Location loc = op->getLoc();
822825
auto typeConverter = getTypeConverter();
823-
auto *ctx = rewriter.getContext();
826+
MLIRContext *ctx = rewriter.getContext();
824827

825828
// Determine the vectorization size
826829
Type valueElemTy =
827830
typeConverter->convertType(getElementTypeOrSelf(op.getType()));
828831
unsigned numElems = getTotalElemsPerThread(op.getType());
829832
unsigned vec = 1;
830833

831-
SmallVector<Value> ptrElems;
832-
SmallVector<Value> maskElems;
833-
834+
SmallVector<Value> ptrElems, maskElems, otherElems;
834835
bool otherIsSplatConstInt = false;
835836
int64_t splatVal = 0;
836-
SmallVector<Value> otherElems;
837837

838838
if (isTensorPointerType(op.getPtr().getType())) {
839-
if (rewriteTensorPointerLoad(op, adaptor, rewriter).succeeded()) {
840-
return success();
841-
} else {
842-
// TODO: (johnlu) set the vector size > 1; Need to prove the memory is
843-
// contiguous on the fast changing dim when fallback to gather load.
844-
Type resultType = op.getType();
845-
auto tensorType = cast<RankedTensorType>(resultType);
846-
std::tie(ptrElems, maskElems, otherElems) =
847-
convertBlockPtrToTensorOfPtr(
848-
loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
849-
op.getBoundaryCheck(), op.getPadding());
850-
}
839+
// TODO: (johnlu) set the vector size > 1; Need to prove the memory is
840+
// contiguous on the fast changing dim when fallback to gather load.
841+
auto tensorType = cast<RankedTensorType>(op.getType());
842+
std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
843+
loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
844+
op.getBoundaryCheck(), op.getPadding());
851845
} else {
852846
// original values
853847
Value ptr = op.getPtr();
@@ -922,7 +916,7 @@ struct LoadOpConversion
922916
for (size_t s = 0; s < size; ++s) {
923917
Value falseVal = otherElems[vecStart + ii * size + s];
924918
Value sVal = createIndexAttrConstant(
925-
rewriter, loc, this->getTypeConverter()->getIndexType(), s);
919+
rewriter, loc, typeConverter->getIndexType(), s);
926920
v = insert_element(vecTy, v, falseVal, sVal);
927921
}
928922
v = bitcast(v, IntegerType::get(ctx, width));
@@ -934,7 +928,7 @@ struct LoadOpConversion
934928
}
935929

936930
Value iiVal = createIndexAttrConstant(
937-
rewriter, loc, this->getTypeConverter()->getIndexType(), ii);
931+
rewriter, loc, typeConverter->getIndexType(), ii);
938932
if (nWords > 1) {
939933
other_ = insert_element(retTy, other_, v, iiVal);
940934
} else {
@@ -1129,31 +1123,27 @@ struct StoreOpConversion
11291123
LogicalResult
11301124
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
11311125
ConversionPatternRewriter &rewriter) const override {
1132-
auto loc = op->getLoc();
1133-
MLIRContext *ctx = rewriter.getContext();
1126+
if (isTensorPointerType(op.getPtr().getType()))
1127+
if (rewriteTensorPointerStore(op, adaptor, rewriter).succeeded())
1128+
return success();
11341129

1130+
Location loc = op->getLoc();
1131+
MLIRContext *ctx = rewriter.getContext();
11351132
Value ptr = op.getPtr();
11361133
Value value = op.getValue();
1137-
1138-
auto valueTy = value.getType();
1134+
Type valueTy = value.getType();
11391135
Type valueElemTy =
11401136
typeConverter->convertType(getElementTypeOrSelf(valueTy));
1141-
SmallVector<Value> ptrElems;
1142-
SmallVector<Value> maskElems;
1137+
SmallVector<Value> ptrElems, maskElems;
11431138
unsigned vec = 1;
11441139

11451140
if (isTensorPointerType(ptr.getType())) {
1146-
if (rewriteTensorPointerStore(op, adaptor, rewriter).succeeded()) {
1147-
return success();
1148-
} else {
1149-
// fallback to scatter store.
1150-
auto tensorType = cast<RankedTensorType>(valueTy);
1151-
SmallVector<Value> dummyOther;
1152-
std::tie(ptrElems, maskElems, dummyOther) =
1153-
convertBlockPtrToTensorOfPtr(loc, adaptor.getPtr(), tensorType,
1154-
valueElemTy, rewriter,
1155-
op.getBoundaryCheck());
1156-
}
1141+
// fallback to scatter store.
1142+
auto tensorType = cast<RankedTensorType>(valueTy);
1143+
SmallVector<Value> dummyOther;
1144+
std::tie(ptrElems, maskElems, dummyOther) = convertBlockPtrToTensorOfPtr(
1145+
loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
1146+
op.getBoundaryCheck());
11571147
} else {
11581148
Value llPtr = adaptor.getPtr();
11591149
Value llMask = adaptor.getMask();
@@ -1561,8 +1551,8 @@ struct AtomicRMWOpConversion
15611551

15621552
rmwVal = bitcast(rmwVal, valueElemTy);
15631553

1564-
// Align pointer by 4 bytes by zeroing lower address bits. Atomically read
1565-
// a vector of two fp16 values as a single i32. The second lowest bit is
1554+
// Align pointer by 4 bytes by zeroing lower address bits. Atomically read a
1555+
// vector of two fp16 values as a single i32. The second lowest bit is
15661556
// extracted to later be used as an index to extract the required vector
15671557
// element.
15681558
assert(isa<LLVM::LLVMPointerType>(rmwPtr.getType()));

0 commit comments

Comments
 (0)