diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 978111e8f547..df7e5dc2592b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -154,6 +154,19 @@ struct JoinOpConversion : public ConvertOpToLLVMPattern { unpackLLElements(loc, adaptor.getRhs(), rewriter); assert(lhsVals.size() == rhsVals.size()); SmallVector joinedVals; + if (isa(op.getLhs().getDefiningOp()) && + resultTy.getElementTypeBitWidth() == 16) { + for (int i = 0; i < lhsVals.size(); i += 2) { + joinedVals.push_back(lhsVals[i]); + joinedVals.push_back(lhsVals[i + 1]); + joinedVals.push_back(rhsVals[i]); + joinedVals.push_back(rhsVals[i + 1]); + } + Value ret = + packLLElements(loc, typeConverter, joinedVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } for (int i = 0; i < lhsVals.size(); i++) { joinedVals.push_back(lhsVals[i]); joinedVals.push_back(rhsVals[i]);