@@ -262,47 +262,12 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
262262
263263 // prepare asm operands
264264 auto *dstsOpr = ptxBuilder.newListOperand ();
265- // If there is a `other` value, use it to init.
266- bool init = other == nullptr ;
267265 for (size_t wordIdx = 0 ; wordIdx < nWords; ++wordIdx) {
268266 auto *opr = ptxBuilder.newOperand (writeConstraint,
269- init); // =r operations
267+ /* init= */ true ); // =r operations
270268 dstsOpr->listAppend (opr);
271269 }
272270
273- if (other) {
274- for (size_t ii = 0 ; ii < nWords; ++ii) {
275- // PTX doesn't support mov.u8, so we need to use mov.u16
276- PTXInstr &mov =
277- ptxBuilder.create <>(" mov" )->o (" u" + std::to_string (movWidth));
278-
279- size_t size = width / valueElemNBits;
280-
281- auto vecTy = LLVM::getFixedVectorType (valueElemTy, size);
282- Value v = undef (vecTy);
283- for (size_t s = 0 ; s < size; ++s) {
284- Value falseVal = otherElems[vecStart + ii * size + s];
285- Value sVal = createIndexAttrConstant (
286- rewriter, loc, typeConverter->getIndexType (), s);
287- v = insert_element (vecTy, v, falseVal, sVal );
288- }
289- v = bitcast (v, IntegerType::get (getContext (), width));
290-
291- PTXInstr::Operand *opr{};
292-
293- if (otherIsSplatConstInt) {
294- int64_t replicatedSplatVal = 0 ;
295- for (size_t s = 0 ; s < movWidth; s += valueElemNBits) {
296- replicatedSplatVal |= splatVal << s;
297- }
298- opr = ptxBuilder.newConstantOperand (replicatedSplatVal);
299- } else
300- opr = ptxBuilder.newOperand (v, readConstraint);
301-
302- mov (dstsOpr->listGet (ii), opr);
303- }
304- }
305-
306271 auto *addrOpr =
307272 ptxBuilder.newAddrOperand (ptrElems[vecStart], " l" , in_off);
308273
@@ -331,6 +296,37 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
331296 else
332297 ld (dstsOpr, addrOpr, evictOpr).predicate (pred, " b" );
333298
299+ if (other) {
300+ for (size_t ii = 0 ; ii < nWords; ++ii) {
301+ // PTX doesn't support mov.u8, so we need to use mov.u16
302+ PTXInstr &mov =
303+ ptxBuilder.create <>(" mov" )->o (" u" + std::to_string (movWidth));
304+
305+ size_t size = width / valueElemNBits;
306+
307+ auto vecTy = LLVM::getFixedVectorType (valueElemTy, size);
308+ Value v = undef (vecTy);
309+ for (size_t s = 0 ; s < size; ++s) {
310+ Value falseVal = otherElems[vecStart + ii * size + s];
311+ Value sVal = createIndexAttrConstant (
312+ rewriter, loc, typeConverter->getIndexType (), s);
313+ v = insert_element (vecTy, v, falseVal, sVal );
314+ }
315+ v = bitcast (v, IntegerType::get (getContext (), width));
316+
317+ PTXInstr::Operand *opr{};
318+
319+ if (otherIsSplatConstInt) {
320+ for (size_t s = 0 ; s < 32 ; s += valueElemNBits)
321+ splatVal |= splatVal << valueElemNBits;
322+ opr = ptxBuilder.newConstantOperand (splatVal);
323+ } else
324+ opr = ptxBuilder.newOperand (v, readConstraint);
325+
326+ mov (dstsOpr->listGet (ii), opr).predicateNot (pred, " b" );
327+ }
328+ }
329+
334330 // Create inline ASM signature
335331 SmallVector<Type> retTys (nWords, IntegerType::get (getContext (), width));
336332 Type retTy = retTys.size () > 1
0 commit comments