Skip to content

Commit 21565a3

Browse files
committed
hack me
Signed-off-by: dchigarev <[email protected]>
1 parent 496d41c commit 21565a3

File tree

3 files changed

+75
-15
lines changed

3 files changed

+75
-15
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,13 @@ Value getSmemVecAddrNEW(const LinearLayout &regLayout,
367367
// solution for all swizzled shared memory scenarios, including the edge case
368368
// mentioned above.
369369
if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
370-
smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout,
370+
auto res = applyLinearLayout(loc, rewriter, regToSharedLayout,
371371
{{kRegister, regId},
372372
{kLane, laneId},
373373
{kWarp, warpId},
374-
{kBlock, blockId}})[0]
375-
.second;
374+
{kBlock, blockId}});
375+
std::cout << "linearLayRes.size(): " << res.size() << "\n";
376+
smemOffset = res[0].second;
376377
} else { // Case 2 -> rank-reduced swizzling
377378
assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
378379
assert(!sharedEnc.getHasLeadingOffset() &&
@@ -426,7 +427,7 @@ Value getSmemVecAddrNEW(const LinearLayout &regLayout,
426427
} // namespace
427428

428429

429-
bool getBoolFromEnv(const std::string& envVar, bool defaultValue = false) {
430+
static bool getBoolFromEnv(const std::string& envVar, bool defaultValue = false) {
430431
const char* value = std::getenv(envVar.c_str());
431432
if (value == nullptr) {
432433
return defaultValue; // Return default if the variable is not set
@@ -549,10 +550,18 @@ bool emitTransferBetweenRegistersAndSharedNEW(
549550
StringAttr kWarp = str_attr("warp");
550551

551552
auto shape = sharedTy.getShape();
553+
llvm::dbgs() << "registerTy enc\n";
554+
registerTy.dump();
555+
registerTy.getEncoding().dump();
556+
llvm::dbgs() << "shape: "; for (auto &el : shape) { llvm::dbgs() << el << " ";} llvm::dbgs() << "\n";
552557
LinearLayout regLayout =
553558
triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
554559
printLinearThing(regLayout, "regLayout");
555560

561+
llvm::dbgs() << "sharedTy enc\n";
562+
sharedTy.dump();
563+
sharedTy.getEncoding().dump();
564+
llvm::dbgs() << "shape: "; for (auto &el : shape) { llvm::dbgs() << el << " ";} llvm::dbgs() << "\n";
556565
LinearLayout sharedLayout = triton::gpu::toLinearLayout(
557566
shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
558567
printLinearThing(sharedLayout, "sharedLayout");
@@ -653,13 +662,30 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
653662
bool success = emitTransferBetweenRegistersAndShared(
654663
dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
655664
rewriter, target, [&](VectorType vecTy, Value vecAddr) {
656-
auto vecVal = load(vecTy, vecAddr);
657-
vecVal.setAlignment(vecTy.getNumElements() *
658-
elemLlvmTy.getIntOrFloatBitWidth() / 8);
659-
660-
for (int v = 0; v < vecTy.getNumElements(); v++) {
661-
ret.push_back(extract_element(elemLlvmTy, vecVal, i32_val(v)));
665+
if (vecTy.getNumElements() >= 64) {
666+
assert(vecTy.getNumElements() % 64 == 0);
667+
for (int i = 0; i < vecTy.getNumElements(); i+=64) {
668+
auto smallVecTy = vec_ty(elemLlvmTy, 64);
669+
auto vecAddrNew = gep(vecAddr.getType(), i32_ty, vecAddr, SmallVector<Value>({i32_val(i)}));
670+
auto vecVal = load(smallVecTy, vecAddrNew);
671+
vecVal.setAlignment(smallVecTy.getNumElements() *
672+
elemLlvmTy.getIntOrFloatBitWidth() / 8);
673+
674+
for (int v = 0; v < 64; v++) {
675+
ret.push_back(extract_element(elemLlvmTy, vecVal, i32_val(v)));
676+
}
677+
}
678+
679+
} else {
680+
auto vecVal = load(vecTy, vecAddr);
681+
vecVal.setAlignment(vecTy.getNumElements() *
682+
elemLlvmTy.getIntOrFloatBitWidth() / 8);
683+
684+
for (int v = 0; v < vecTy.getNumElements(); v++) {
685+
ret.push_back(extract_element(elemLlvmTy, vecVal, i32_val(v)));
686+
}
662687
}
688+
663689
});
664690
if (!success)
665691
llvm::report_fatal_error("Failed to emit transfer from shared to register");

python/triton/runtime/build.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,5 +93,13 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
9393
if os.getenv("VERBOSE"):
9494
print(" ".join(cc_cmd))
9595

96-
subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
96+
result = subprocess.run(cc_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True)
97+
98+
if result.returncode != 0:
99+
print(f"Error: Command failed with exit code {result.returncode}")
100+
if result.stderr:
101+
print("Error output:", result.stderr)
102+
103+
# breakpoint()
104+
# subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
97105
return so

third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ using namespace mlir;
2020
using namespace mlir::triton;
2121
using namespace mlir::triton::gpu;
2222

23+
static bool getBoolFromEnv(const std::string& envVar, bool defaultValue = false) {
24+
const char* value = std::getenv(envVar.c_str());
25+
if (value == nullptr) {
26+
return defaultValue; // Return default if the variable is not set
27+
}
28+
29+
std::string strValue(value);
30+
for (char& c : strValue) c = std::tolower(c); // Convert to lowercase
31+
32+
return (strValue == "1" || strValue == "true" || strValue == "yes" || strValue == "on");
33+
}
34+
2335
// blocked -> shared.
2436
// Swizzling in shared memory to avoid bank conflict. Normally used for
2537
// A/B operands of dots.
@@ -78,11 +90,14 @@ struct LocalAllocOpConversion
7890
LogicalResult
7991
matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor,
8092
ConversionPatternRewriter &rewriter) const override {
81-
llvm::dbgs() << "\n\n===LocalAllocOpConversion BEFORE===\n";
82-
op.dump();
83-
for (auto& x : *(op->getParentRegion())) {
84-
x.dump();
93+
if (getBoolFromEnv("TR_LONG_IR")) {
94+
llvm::dbgs() << "\n\n===LocalAllocOpConversion BEFORE===\n";
95+
op.dump();
96+
for (auto& x : *(op->getParentRegion())) {
97+
x.dump();
98+
}
8599
}
100+
86101
if (!op.isSharedMemoryAlloc())
87102
return failure();
88103
Location loc = op->getLoc();
@@ -99,17 +114,24 @@ struct LocalAllocOpConversion
99114
loc, rewriter);
100115
// If there is an initial tensor, store it into the shared memory.
101116
if (op.getSrc()) {
117+
llvm::dbgs() << "LocalAllocOp adaptor.src():\n";
118+
adaptor.getSrc().dump();
119+
llvm::dbgs() << "LocalAllocOp op.src() and op itself:\n";
120+
op.getSrc().dump();
121+
op.dump();
102122
lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
103123
adaptor.getSrc(), smemObj, typeConverter,
104124
rewriter, targetInfo);
105125
}
106126
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
107127
rewriter.replaceOp(op, retVal);
128+
if (getBoolFromEnv("TR_LONG_IR")) {
108129
llvm::dbgs() << "\n\n===LocalAllocOpConversion AFTER===\n";
109130
retVal.dump();
110131
for (auto& x : *(retVal.getParentRegion())) {
111132
x.dump();
112133
}
134+
}
113135
return success();
114136
}
115137

@@ -195,11 +217,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
195217
lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
196218
const LLVMTypeConverter *typeConverter,
197219
ConversionPatternRewriter &rewriter) const {
220+
if (getBoolFromEnv("TR_LONG_IR")) {
198221
llvm::dbgs() << "\n\n===lowerSharedToDistributed BEFORE===\n";
199222
op.dump();
200223
for (auto& x : *(op->getParentRegion())) {
201224
x.dump();
202225
}
226+
}
203227
auto loc = op.getLoc();
204228
auto srcTy = op.getSrc().getType();
205229
auto dstTy = op.getResult().getType();
@@ -214,11 +238,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
214238

215239
Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
216240
rewriter.replaceOp(op, result);
241+
if (getBoolFromEnv("TR_LONG_IR")) {
217242
llvm::dbgs() << "\n\n===lowerSharedToDistributed AFTER===\n";
218243
result.dump();
219244
for (auto& x : *(result.getParentRegion())) {
220245
x.dump();
221246
}
247+
}
222248
return success();
223249
}
224250

0 commit comments

Comments
 (0)