@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
   // opIdx: 0 => a, 1 => b
   auto type = cast<triton::MemDescType>(v.getType());
   SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
-  SmallVector<int64_t> offset{0, 0};
+  SmallVector<int64_t> offset(shape.size(), 0);
   Type elementType = type.getElementType();

   // k => (prefetchWidth, k - prefetchWidth)
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                                type.getMemorySpace()),
       v, offsetsVal);

+  // We need to assign kwidth to zero in the case where the parent layout is
+  // Blocked, otherwise the verifier emits a failure. The parent layout is
+  // Blocked only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
       newSmem);
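
For reference, a minimal standalone sketch of the kwidth decision introduced above, assuming the behavior described in the new comment. The ParentEncoding enum and selectKWidth helper are illustrative stand-ins, not MLIR or Triton API, and the prefetch width in main() is only an example value.

#include <cassert>

// Stand-in for the dot operand's parent encoding: Blocked corresponds to the
// Tensor-Cores-disabled case described in the patch comment; Mma stands for
// any non-Blocked parent.
enum class ParentEncoding { Blocked, Mma };

// kwidth must be 0 for a Blocked parent; otherwise it is derived from the
// prefetch width, mirroring the ternary in the hunk above.
int selectKWidth(ParentEncoding parent, int prefetchWidth) {
  return parent == ParentEncoding::Blocked ? 0 : prefetchWidth / 8;
}

int main() {
  assert(selectKWidth(ParentEncoding::Blocked, 32) == 0);
  assert(selectKWidth(ParentEncoding::Mma, 32) == 4);
  return 0;
}
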
@@ -190,6 +196,22 @@ LogicalResult Prefetcher::initialize() {
         break;
       if (!op->getResult(0).hasOneUse())
         break;
+      // Similar to issues faced in HoistLayoutConversion pattern in
+      // OptimizeDotOperands.cpp, we can't propagate through type casts from
+      // predicates as they aren't supported in Triton when encoded with dot_op
+      // layout.
+      if (isa<arith::UIToFPOp>(op)) {
+        Type srcType = getElementTypeOrSelf(op->getOperand(0));
+        if (srcType.isInteger(1))
+          break;
+      }
+      // Propagation through ExpandDims is currently not supported. This
+      // blindly replaces the encoding with dot encoding, but ExpandDims
+      // requires a SliceEncoding. This could be rewritten to support it
+      // somehow, but I don't think it's trivial & it's currently crashing.
+      if (isa<ExpandDimsOp>(op)) {
+        break;
+      }
       rets.push_back(op->getOperand(0));
       if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
         foundConvertFromShared = true;
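
The new bail-out checks in initialize() can also be modeled with the following standalone sketch. The Op struct, the string kinds, and shouldStopAt are placeholders chosen only to illustrate the conditions visible in the hunk; they are not the real MLIR/Triton types or API.

#include <cassert>
#include <string>

// Minimal stand-in for an operation encountered during the def-chain walk.
struct Op {
  int numUses;       // uses of the op's result; the walk needs exactly one
  std::string kind;  // e.g. "arith.uitofp", "tt.expand_dims"
  bool srcIsI1;      // whether a cast's source element type is i1
};

// Returns true when the walk should stop instead of rewriting this op's
// operand with a dot_op-encoded tensor type.
bool shouldStopAt(const Op &op) {
  if (op.numUses != 1)
    return true;
  // Casts from i1 predicates cannot be expressed with a dot_op layout.
  if (op.kind == "arith.uitofp" && op.srcIsI1)
    return true;
  // ExpandDims requires a SliceEncoding operand, so blindly substituting a
  // dot_op encoding would be invalid.
  if (op.kind == "tt.expand_dims")
    return true;
  return false;
}

int main() {
  assert(shouldStopAt({1, "arith.uitofp", true}));
  assert(shouldStopAt({1, "tt.expand_dims", false}));
  assert(!shouldStopAt({1, "arith.addf", false}));
  return 0;
}
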