@@ -667,7 +667,7 @@ struct PrefetchOpConversion
667667 Value ptr = op.getPtr ();
668668 auto ptrType = cast<PointerType>(ptr.getType ());
669669 auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
670- Type eltTy = tensorType.getElementType ();
670+ Type eltTy = getTypeConverter ()-> convertType ( tensorType.getElementType () );
671671 const ArrayRef<int64_t > shapeRef = tensorType.getShape ();
672672 SmallVector<int64_t > tensorShape{shapeRef.begin (), shapeRef.end ()};
673673
@@ -882,7 +882,7 @@ struct PrefetchOpConversion
882882 mlir::ceil<int64_t >(shardTensorShape[0 ], prefetchShape[0 ]),
883883 mlir::ceil<int64_t >(shardTensorShape[1 ], prefetchShape[1 ])};
884884
885- Type eltTy = tensorType.getElementType ();
885+ Type eltTy = getTypeConverter ()-> convertType ( tensorType.getElementType () );
886886 unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
887887 unsigned tileWidthInElem = prefetchShape[1 ];
888888 unsigned tileHeightInElem = prefetchShape[0 ];
@@ -1061,7 +1061,7 @@ struct LoadOpToBlockIOConversion
10611061 " Only row_major or column_major is allowed" );
10621062 const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
10631063
1064- Type eltTy = tensorType.getElementType ();
1064+ Type eltTy = getTypeConverter ()-> convertType ( tensorType.getElementType () );
10651065 unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
10661066
10671067 auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout (
@@ -1893,7 +1893,7 @@ struct LoadOpToBlockIOConversion
18931893
18941894 // Step 2: Right now we only support DPAS related layout to simplify the
18951895 // lowering.
1896- Type eltTy = tensorType.getElementType ();
1896+ Type eltTy = getTypeConverter ()-> convertType ( tensorType.getElementType () );
18971897 unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth ();
18981898 DpasEncodingAttr dpasLayout = getDpasLayout (tensorType);
18991899 const ArrayRef<int64_t > tensorShape = tensorType.getShape ();
0 commit comments