Skip to content

Commit 86dd31f

Browse files
authored
Add tt.trans op support (#126)
1 parent 3e825ec commit 86dd31f

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

lib/Conversion/TritonGPUToSPIRV/ViewOpToSPIRV.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -205,30 +205,28 @@ struct ExpandDimsOpSPIRVConversion
205205
}
206206
};
207207

208-
#if 0
209-
struct TransOpConversion
210-
: public ConvertTritonGPUOpToLLVMPattern<triton::TransOp> {
211-
using ConvertTritonGPUOpToLLVMPattern<
212-
triton::TransOp>::ConvertTritonGPUOpToLLVMPattern;
208+
struct TransOpSPIRVConversion
209+
: public ConvertTritonGPUOpToSPIRVPattern<triton::TransOp> {
210+
using ConvertTritonGPUOpToSPIRVPattern<
211+
triton::TransOp>::ConvertTritonGPUOpToSPIRVPattern;
213212

214213
LogicalResult
215214
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
216215
ConversionPatternRewriter &rewriter) const override {
217216
Location loc = op->getLoc();
218217
auto srcSmemObj =
219-
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
218+
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
220219
SmallVector<Value> dstStrides = {srcSmemObj.strides[1],
221220
srcSmemObj.strides[0]};
222221
SmallVector<Value> dstOffsets = {srcSmemObj.offsets[1],
223222
srcSmemObj.offsets[0]};
224223
auto dstSmemObj =
225-
SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
224+
SharedMemoryObject(srcSmemObj.base, dstStrides, dstOffsets);
226225
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
227226
rewriter.replaceOp(op, retVal);
228227
return success();
229228
}
230229
};
231-
#endif
232230

233231
void populateViewOpToSPIRVPatterns(
234232
TritonGPUToSPIRVTypeConverter &typeConverter, mlir::MLIRContext *context,
@@ -245,4 +243,5 @@ void populateViewOpToSPIRVPatterns(
245243
mlir::spirv::checkOpSupported(computeCapability,
246244
"INTELConvertFToBF16Op"));
247245
patterns.add<CatOpSPIRVConversion>(typeConverter, context, benefit);
246+
patterns.add<TransOpSPIRVConversion>(typeConverter, context, benefit);
248247
}

0 commit comments

Comments
 (0)