Skip to content

Commit 35c6c7c

Browse files
davidberard98bertmaher
authored andcommitted
Revert "[BACKEND] Optimize code generation for load with other arg (triton-lang#4582)"
This reverts commit 78af5c9.
1 parent 3e00b0e commit 35c6c7c

File tree

2 files changed

+32
-38
lines changed

2 files changed

+32
-38
lines changed

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
1717
// CHECK-LABEL: basic_load
1818
tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
1919
// CHECK: llvm.inline_asm
20-
// CHECK-SAME: mov.u32 $0, $1;
21-
// CHECK-SAME: @$3 ld.global.b32 { $0 }, [ $2 + 0 ];", "=r,r,l,b"
2220
// CHECK: llvm.inline_asm
2321
%1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr<f32>, #blocked0>
2422
tt.return

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -262,47 +262,12 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
262262

263263
// prepare asm operands
264264
auto *dstsOpr = ptxBuilder.newListOperand();
265-
// If there is a `other` value, use it to init.
266-
bool init = other == nullptr;
267265
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
268266
auto *opr = ptxBuilder.newOperand(writeConstraint,
269-
init); // =r operations
267+
/*init=*/true); // =r operations
270268
dstsOpr->listAppend(opr);
271269
}
272270

273-
if (other) {
274-
for (size_t ii = 0; ii < nWords; ++ii) {
275-
// PTX doesn't support mov.u8, so we need to use mov.u16
276-
PTXInstr &mov =
277-
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
278-
279-
size_t size = width / valueElemNBits;
280-
281-
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
282-
Value v = undef(vecTy);
283-
for (size_t s = 0; s < size; ++s) {
284-
Value falseVal = otherElems[vecStart + ii * size + s];
285-
Value sVal = createIndexAttrConstant(
286-
rewriter, loc, typeConverter->getIndexType(), s);
287-
v = insert_element(vecTy, v, falseVal, sVal);
288-
}
289-
v = bitcast(v, IntegerType::get(getContext(), width));
290-
291-
PTXInstr::Operand *opr{};
292-
293-
if (otherIsSplatConstInt) {
294-
int64_t replicatedSplatVal = 0;
295-
for (size_t s = 0; s < movWidth; s += valueElemNBits) {
296-
replicatedSplatVal |= splatVal << s;
297-
}
298-
opr = ptxBuilder.newConstantOperand(replicatedSplatVal);
299-
} else
300-
opr = ptxBuilder.newOperand(v, readConstraint);
301-
302-
mov(dstsOpr->listGet(ii), opr);
303-
}
304-
}
305-
306271
auto *addrOpr =
307272
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);
308273

@@ -331,6 +296,37 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
331296
else
332297
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");
333298

299+
if (other) {
300+
for (size_t ii = 0; ii < nWords; ++ii) {
301+
// PTX doesn't support mov.u8, so we need to use mov.u16
302+
PTXInstr &mov =
303+
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
304+
305+
size_t size = width / valueElemNBits;
306+
307+
auto vecTy = LLVM::getFixedVectorType(valueElemTy, size);
308+
Value v = undef(vecTy);
309+
for (size_t s = 0; s < size; ++s) {
310+
Value falseVal = otherElems[vecStart + ii * size + s];
311+
Value sVal = createIndexAttrConstant(
312+
rewriter, loc, typeConverter->getIndexType(), s);
313+
v = insert_element(vecTy, v, falseVal, sVal);
314+
}
315+
v = bitcast(v, IntegerType::get(getContext(), width));
316+
317+
PTXInstr::Operand *opr{};
318+
319+
if (otherIsSplatConstInt) {
320+
for (size_t s = 0; s < 32; s += valueElemNBits)
321+
splatVal |= splatVal << valueElemNBits;
322+
opr = ptxBuilder.newConstantOperand(splatVal);
323+
} else
324+
opr = ptxBuilder.newOperand(v, readConstraint);
325+
326+
mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
327+
}
328+
}
329+
334330
// Create inline ASM signature
335331
SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
336332
Type retTy = retTys.size() > 1

0 commit comments

Comments
 (0)