@@ -827,46 +827,34 @@ struct LoadOpConversion
827827 rewriteTensorPointerLoad (op, adaptor, rewriter).succeeded ())
828828 return success ();
829829
830- auto loc = op->getLoc ();
831- auto typeConverter = getTypeConverter ();
832- auto *ctx = rewriter.getContext ();
830+ Location loc = op->getLoc ();
831+ TritonIntelGPUToLLVMTypeConverter *typeConverter = getTypeConverter ();
832+ MLIRContext *ctx = rewriter.getContext ();
833+ Value ptr = op.getPtr ();
834+ Value mask = op.getMask ();
835+ Value llMask = adaptor.getMask ();
833836
834837 // Determine the vectorization size
835838 Type valueElemTy =
836839 typeConverter->convertType (getElementTypeOrSelf (op.getType ()));
837840 unsigned numElems = getTotalElemsPerThread (op.getType ());
838- unsigned vec = 1 ;
841+ unsigned vec = getVectorSize (ptr);
842+ if (llMask)
843+ vec = std::min<size_t >(vec, getMaskAlignment (mask));
839844
840845 SmallVector<Value> ptrElems, maskElems, otherElems;
841846 bool otherIsSplatConstInt = false ;
842847 int64_t splatVal = 0 ;
843848
844- if (isTensorPointerType (op.getPtr ().getType ())) {
845- Value ptr = op.getPtr ();
846- Value mask = op.getMask ();
847- Value llMask = adaptor.getMask ();
848- vec = getVectorSize (ptr);
849- if (llMask)
850- vec = std::min<size_t >(vec, getMaskAlignment (mask));
851-
852- Type resultType = op.getType ();
853- auto tensorType = cast<RankedTensorType>(resultType);
849+ if (isTensorPointerType (ptr.getType ())) {
850+ auto tensorType = cast<RankedTensorType>(op.getType ());
854851 std::tie (ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr (
855852 loc, adaptor.getPtr (), tensorType, valueElemTy, rewriter,
856853 op.getBoundaryCheck (), op.getPadding ());
857854 } else {
858- // original values
859- Value ptr = op.getPtr ();
860855 Value other = op.getOther ();
861- Value mask = op.getMask ();
862-
863- // adaptor values
864856 Value llPtr = adaptor.getPtr ();
865- Value llMask = adaptor.getMask ();
866857 Value llOther = adaptor.getOther ();
867- vec = getVectorSize (ptr);
868- if (llMask)
869- vec = std::min<size_t >(vec, getMaskAlignment (mask));
870858
871859 // Get the LLVM values for pointers
872860 ptrElems = unpackLLElements (loc, llPtr, rewriter);
@@ -1141,39 +1129,35 @@ struct StoreOpConversion
11411129 return success ();
11421130
11431131 Location loc = op->getLoc ();
1132+ TritonIntelGPUToLLVMTypeConverter *typeConverter = getTypeConverter ();
11441133 MLIRContext *ctx = rewriter.getContext ();
1145- Value value = op.getValue ();
1146-
11471134 Value ptr = op.getPtr ();
1135+ Value mask = op.getMask ();
1136+ Value llMask = adaptor.getMask ();
1137+
1138+ // Determine the vectorization size
1139+ Value value = op.getValue ();
11481140 Type valueTy = value.getType ();
11491141 Type valueElemTy =
11501142 typeConverter->convertType (getElementTypeOrSelf (valueTy));
1151- SmallVector<Value> ptrElems;
1152- SmallVector<Value> maskElems;
1153- unsigned vec = 1 ;
1143+ SmallVector<Value> ptrElems, maskElems;
1144+ unsigned vec = getVectorSize (ptr);
1145+ if (llMask)
1146+ vec = std::min<size_t >(vec, getMaskAlignment (mask));
11541147
11551148 if (isTensorPointerType (ptr.getType ())) {
1156- // fallback to scatter store.
11571149 auto tensorType = cast<RankedTensorType>(valueTy);
11581150 SmallVector<Value> dummyOther;
11591151 std::tie (ptrElems, maskElems, dummyOther) = convertBlockPtrToTensorOfPtr (
11601152 loc, adaptor.getPtr (), tensorType, valueElemTy, rewriter,
11611153 op.getBoundaryCheck ());
11621154 } else {
11631155 Value llPtr = adaptor.getPtr ();
1164- Value llMask = adaptor.getMask ();
1165-
1166- vec = getVectorSize (ptr);
11671156
11681157 ptrElems = unpackLLElements (loc, llPtr, rewriter);
11691158
1170- // Determine the vectorization size
11711159 if (llMask) {
1172- Value mask = op.getMask ();
11731160 maskElems = unpackLLElements (loc, llMask, rewriter);
1174-
1175- unsigned maskAlign = getMaskAlignment (mask);
1176- vec = std::min (vec, maskAlign);
11771161 }
11781162 }
11791163
@@ -1183,7 +1167,7 @@ struct StoreOpConversion
11831167 assert (!maskElems.size () ||
11841168 valueElems.size () == maskElems.size () && " Mask size mismatch" );
11851169
1186- Value mask = redundantDataMask (valueTy, rewriter, loc, targetInfo);
1170+ mask = redundantDataMask (valueTy, rewriter, loc, targetInfo);
11871171 const size_t dtsize =
11881172 std::max<int >(1 , valueElemTy.getIntOrFloatBitWidth () / 8 );
11891173 const size_t valueElemNBits = dtsize * 8 ;
0 commit comments