diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 727393dbdc..dfff5578ff 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -736,6 +736,10 @@ SmallVector unpackLLVector(Location loc, Value llvmVec, Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter) { assert(vals.size() > 0); + // Return scalar directly instead of creating a 1-element vector + if (vals.size() == 1) { + return vals[0]; + } auto vecType = vec_ty(vals[0].getType(), vals.size()); auto b = TritonLLVMOpBuilder(loc, rewriter); Value vec = b.undef(vecType);