@@ -149,7 +149,6 @@ static tt::LoadOp findUsedLoad(Value val) {
 }
 
 static bool getTransposeFlagFromValue(Value val) {
-  bool isTransposed = false;
   Value loadPtr = val;
   // backward: from dot operands to tt.load
   if (llvm::any_of(val.getUsers(),
@@ -167,23 +166,26 @@ static bool getTransposeFlagFromValue(Value val) {
   if (auto blockArg = dyn_cast<BlockArgument>(loadPtr)) {
     unsigned argIdx = blockArg.getArgNumber();
     if (auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(
-            blockArg.getParentBlock()->getParentOp())) {
-      auto inits = llvm::to_vector(loopLikeOp.getInits());
-      if (auto glueOp = inits[argIdx - 1].getDefiningOp<ttgi::GlueOp>()) {
-        if (auto tempPtr =
-                glueOp.getOperands()[0].getDefiningOp<tt::MakeTensorPtrOp>()) {
-          loadPtr = tempPtr.getResult();
-        }
-      }
-    }
+            blockArg.getParentBlock()->getParentOp()))
+      loadPtr = loopLikeOp.getInits()[argIdx - 1];
+  }
+
+  if (auto glueOp = loadPtr.getDefiningOp<ttgi::GlueOp>()) {
+    if (isa_and_present<tt::MakeTensorPtrOp, tt::AdvanceOp>(
+            glueOp.getOperands()[0].getDefiningOp()))
+      loadPtr = glueOp.getOperands()[0];
   }
 
   if (auto tensorPtr = loadPtr.getDefiningOp<tt::MakeTensorPtrOp>()) {
     ArrayRef<int32_t> order = tensorPtr.getOrder();
     auto rank = order.size();
-    isTransposed = (order[rank - 2] != 1);
+    return (order[rank - 2] != 1);
   }
-  return isTransposed;
+
+  if (auto advOp = loadPtr.getDefiningOp<tt::AdvanceOp>())
+    return getTransposeFlagFromValue(advOp.getPtr());
+
+  return false;
 }
 
 static void rewriteLoadWithSLM(ModuleOp &m, DenseSet<Value> &dotWithSLMOperands,
@@ -275,6 +277,14 @@ class MatchTargetSizePass
     MLIRContext *ctx = &getContext();
     ModuleOp m = getOperation();
 
+    // By default, tritongpu are lowered to simt mode (threads-per-warp=16)
+    // instead of simd mode (threads-per-warp=1).
+    // FIXME: force threads-per-warp=16 in simt (this should be done via an
+    // analysis designed to determine whether the kernel contains tt.dot
+    // operations that use block pointers).
+    m->setAttr("triton_gpu.threads-per-warp",
+               IntegerAttr::get(IntegerType::get(ctx, 32), 16));
+
     Workload workload = Workload::None;
     m.walk([&](scf::ForOp forOp) {
       if (Attribute attr = forOp->getAttr(AttrWorkloadName))
@@ -352,14 +362,6 @@ class MatchTargetSizePass
     canonicalize();
     LLVM_DEBUG(llvm::dbgs() << "Module after canonicalization:\n"
                             << m << "\n\n");
-
-    // By default, tritongpu are lowered to simt mode (threads-per-warp=16)
-    // instead of simd mode (threads-per-warp=1).
-    // FIXME: force threads-per-warp=16 in simt (this should be done via an
-    // analysis designed to determine whether the kernel contains tt.dot
-    // operations that use block pointers).
-    m->setAttr("triton_gpu.threads-per-warp",
-               IntegerAttr::get(IntegerType::get(ctx, 32), 16));
   }
 
 private:
@@ -379,8 +381,8 @@ class MatchTargetSizePass
                                     bool isTransposed) const;
 
   std::tuple<SmallVector<int64_t>, Type, SmallVector<int64_t>>
-  getSubTypeAndShape(Type type, bool isTransposed = false,
-                     bool useSLM = false) const;
+  getSubTypeAndShape(Type type, bool isTransposed = false, bool useSLM = false,
+                     bool keepEncoding = false) const;
 
   Value getSubVal(Operation *op, Value val, ArrayRef<int64_t> srcOffset,
                   ArrayRef<int64_t> dstSize);
@@ -753,7 +755,7 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type,
 /// return [shape, subType, subSize] for a tensor (or pointer to tensor)
 std::tuple<SmallVector<int64_t>, Type, SmallVector<int64_t>>
 MatchTargetSizePass::getSubTypeAndShape(Type type, bool isTransposed,
-                                        bool useSLM) const {
+                                        bool useSLM, bool keepEncoding) const {
   if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
     Attribute layout = tensorType.getEncoding();
     assert(layout && "Expecting a valid layout");
@@ -771,15 +773,16 @@ MatchTargetSizePass::getSubTypeAndShape(Type type, bool isTransposed,
       subSize[1] = std::min(subSize[1], shape[1]);
     }
 
-    auto subType = RankedTensorType::get(
-        subSize, tensorType.getElementType() /*no encoding*/);
+    auto subType = RankedTensorType::get(subSize, tensorType.getElementType(),
+                                         keepEncoding ? tensorType.getEncoding()
+                                                      : Attribute{});
     return {shape, subType, subSize};
   }
 
   if (auto ptrType = dyn_cast<tt::PointerType>(type)) {
     Type pointeeType = ptrType.getPointeeType();
     auto [shape, subType, subSize] =
-        getSubTypeAndShape(pointeeType, isTransposed, useSLM);
+        getSubTypeAndShape(pointeeType, isTransposed, useSLM, keepEncoding);
     auto newType = tt::PointerType::get(subType, ptrType.getAddressSpace());
     return {shape, newType, subSize};
   }
@@ -1186,8 +1189,11 @@ void MatchTargetSizePass::transformBroadcastOp(ttgi::BroadcastOp op) {
     glue = b.create<ttgi::GlueOp>(loc, resType, ops);
   } else if (srcDim0 == 1 && srcDim1 == resDim1) {
     // Handle row-vector broadcasts, e.g. 1x64 --> 16x64.
+    // This kind of broadcast requires that the tensor type is kept intact by
+    // SIMT lowering, hence propagate the encoding here.
     auto subRowVecTy =
-        RankedTensorType::get({1, tType.getShape()[1]}, tType.getElementType());
+        RankedTensorType::get({1, tType.getShape()[1]}, tType.getElementType(),
+                              srcType.getEncoding());
 
     // How many extracts do we need to cover the width of the input tensor?
     unsigned nExtracts = srcDim1 / dstDim1;
@@ -1222,9 +1228,10 @@ void MatchTargetSizePass::transformMakeRangeOp(tt::MakeRangeOp op) {
 
   unsigned start = op.getStart();
   unsigned end = op.getEnd();
-  assert(start == 0 && end % subgroupSize == 0 && "Unsupported range");
+  assert(start == 0 && (end <= subgroupSize || end % subgroupSize == 0) &&
+         "Unsupported range");
 
-  if (end == subgroupSize)
+  if (end <= subgroupSize)
     // nothing to do
     return;
 
@@ -1240,6 +1247,7 @@ void MatchTargetSizePass::transformMakeRangeOp(tt::MakeRangeOp op) {
   Location loc = op.getLoc();
   RankedTensorType origTy = op.getType();
   Type elemTy = origTy.getElementType();
+  // Propagate encoding to keep tensor during SIMT lowering.
   auto subRangeTy =
       RankedTensorType::get({subgroupSize}, elemTy, origTy.getEncoding());
   auto subRange = b.create<tt::MakeRangeOp>(loc, subRangeTy, 0, subgroupSize);
@@ -1310,8 +1318,16 @@ void MatchTargetSizePass::transformGenericOp(Operation *op) {
         cast<tt::PointerType>(load.getPtr().getType()).getAddressSpace();
     useSLM = (ptrAS == TritonGEN::TritonGENMemorySpace::kWorkgroup);
   }
+
+  // Keep encoding on certain tensors to leave them untouched during SIMT
+  // lowering. Currently, this is required for "row vectors" (= `tensor<1xN>`).
+  bool keepEncoding = false;
+  if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
+    ArrayRef<int64_t> shape = tensorType.getShape();
+    keepEncoding = shape.size() == 2 && shape[0] == 1 && shape[1] > 1;
+  }
   auto [shape, subType, subSize] =
-      getSubTypeAndShape(type, isTransposed, useSLM);
+      getSubTypeAndShape(type, isTransposed, useSLM, keepEncoding);
 
   unsigned dim = shape.size();
   OpBuilder b(op);
@@ -1328,8 +1344,8 @@ void MatchTargetSizePass::transformGenericOp(Operation *op) {
                    [&](Value operand) {
                      Type type = operand.getType();
                      if (isa<tt::PointerType, RankedTensorType>(type)) {
-                       Type subOpndType = std::get<1>(
-                           getSubTypeAndShape(type, isTransposed, useSLM));
+                       Type subOpndType = std::get<1>(getSubTypeAndShape(
+                           type, isTransposed, useSLM, keepEncoding));
                        Value newOp = b.create<ttgi::ExtractOp>(
                            loc, subOpndType, operand, idx);
                        return newOp;