Skip to content

Commit ce08f57

Browse files
[LoadStoreOpToLLVM] Use type converter to get tensor element type (#5073)
This PR is to address review comment from #5068 (comment). --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent f9885fb commit ce08f57

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ struct PrefetchOpConversion
667667
Value ptr = op.getPtr();
668668
auto ptrType = cast<PointerType>(ptr.getType());
669669
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
670-
Type eltTy = tensorType.getElementType();
670+
Type eltTy = getTypeConverter()->convertType(tensorType.getElementType());
671671
const ArrayRef<int64_t> shapeRef = tensorType.getShape();
672672
SmallVector<int64_t> tensorShape{shapeRef.begin(), shapeRef.end()};
673673

@@ -882,7 +882,7 @@ struct PrefetchOpConversion
882882
mlir::ceil<int64_t>(shardTensorShape[0], prefetchShape[0]),
883883
mlir::ceil<int64_t>(shardTensorShape[1], prefetchShape[1])};
884884

885-
Type eltTy = tensorType.getElementType();
885+
Type eltTy = getTypeConverter()->convertType(tensorType.getElementType());
886886
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
887887
unsigned tileWidthInElem = prefetchShape[1];
888888
unsigned tileHeightInElem = prefetchShape[0];
@@ -1061,7 +1061,7 @@ struct LoadOpToBlockIOConversion
10611061
"Only row_major or column_major is allowed");
10621062
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
10631063

1064-
Type eltTy = tensorType.getElementType();
1064+
Type eltTy = getTypeConverter()->convertType(tensorType.getElementType());
10651065
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
10661066

10671067
auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
@@ -1893,7 +1893,7 @@ struct LoadOpToBlockIOConversion
18931893

18941894
// Step 2: Right now we only support DPAS related layout to simplify the
18951895
// lowering.
1896-
Type eltTy = tensorType.getElementType();
1896+
Type eltTy = getTypeConverter()->convertType(tensorType.getElementType());
18971897
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
18981898
DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
18991899
const ArrayRef<int64_t> tensorShape = tensorType.getShape();

0 commit comments

Comments
 (0)