Skip to content

Commit 83e1f23

Browse files
author
ymweiss
committed
[fix] Fix barrier and printf lowering in GPUToVortex pass
Two fixes to the GPUToVortex lowering pass: 1. Barrier lowering: Declare vx_barrier and vx_num_warps functions inside gpu.module instead of top-level module, so they are visible to kernel code during compilation. 2. Printf lowering: Remove incorrect core_id insertion. vx_printf has the same signature as standard printf (no core_id parameter). Previously the pass was corrupting printf arguments.
1 parent 3c31e91 commit 83e1f23

File tree

1 file changed

+18
-34
lines changed

1 file changed

+18
-34
lines changed

lib/polygeist/Passes/ConvertGPUToVortex.cpp

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,11 @@ struct BarrierOpLowering : public ConvertOpToLLVMPattern<gpu::BarrierOp> {
317317
matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
318318
ConversionPatternRewriter &rewriter) const override {
319319
Location loc = op.getLoc();
320-
auto module = op->getParentOfType<ModuleOp>();
321-
MLIRContext *context = module.getContext();
320+
// Declare functions in gpu.module (not top-level module) so they're visible
321+
auto gpuModule = op->getParentOfType<gpu::GPUModuleOp>();
322+
if (!gpuModule)
323+
return failure();
324+
MLIRContext *context = gpuModule.getContext();
322325

323326
// Allocate barrier ID (simple counter for now)
324327
// TODO: Proper barrier ID allocation to avoid conflicts
@@ -330,36 +333,36 @@ struct BarrierOpLowering : public ConvertOpToLLVMPattern<gpu::BarrierOp> {
330333
auto barIdConstant = rewriter.create<LLVM::ConstantOp>(
331334
loc, i32Type, rewriter.getI32IntegerAttr(barrierId));
332335

333-
// Declare vx_num_warps function to get warp count
334-
auto vxNumWarpsFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("vx_num_warps");
336+
// Declare vx_num_warps function in gpu.module if not already declared
337+
auto vxNumWarpsFunc = gpuModule.lookupSymbol<LLVM::LLVMFuncOp>("vx_num_warps");
335338
if (!vxNumWarpsFunc) {
336339
OpBuilder::InsertionGuard guard(rewriter);
337-
rewriter.setInsertionPointToStart(module.getBody());
340+
rewriter.setInsertionPointToStart(gpuModule.getBody());
338341

339342
auto funcType = LLVM::LLVMFunctionType::get(
340343
i32Type, {}, /*isVarArg=*/false);
341344

342345
vxNumWarpsFunc = rewriter.create<LLVM::LLVMFuncOp>(
343-
module.getLoc(), "vx_num_warps", funcType);
346+
gpuModule.getLoc(), "vx_num_warps", funcType);
344347
}
345348

346349
// Call vx_num_warps() to get number of warps
347350
auto numWarps = rewriter.create<LLVM::CallOp>(
348351
loc, vxNumWarpsFunc, ValueRange{});
349352

350-
// Declare vx_barrier function if not already declared
351-
auto vxBarrierFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("vx_barrier");
353+
// Declare vx_barrier function in gpu.module if not already declared
354+
auto vxBarrierFunc = gpuModule.lookupSymbol<LLVM::LLVMFuncOp>("vx_barrier");
352355
if (!vxBarrierFunc) {
353356
OpBuilder::InsertionGuard guard(rewriter);
354-
rewriter.setInsertionPointToStart(module.getBody());
357+
rewriter.setInsertionPointToStart(gpuModule.getBody());
355358

356359
auto funcType = LLVM::LLVMFunctionType::get(
357360
LLVM::LLVMVoidType::get(context),
358361
{i32Type, i32Type},
359362
/*isVarArg=*/false);
360363

361364
vxBarrierFunc = rewriter.create<LLVM::LLVMFuncOp>(
362-
module.getLoc(), "vx_barrier", funcType);
365+
gpuModule.getLoc(), "vx_barrier", funcType);
363366
}
364367

365368
// Call vx_barrier(barrier_id, num_warps)
@@ -374,10 +377,10 @@ struct BarrierOpLowering : public ConvertOpToLLVMPattern<gpu::BarrierOp> {
374377
}
375378
};
376379

377-
/// Lower printf calls to vx_printf with core ID as first argument
380+
/// Lower printf calls to vx_printf
378381
/// Matches: llvm.call @printf(format, args...)
379-
/// Replaces with: llvm.call @vx_printf(format, cid, args...)
380-
/// where cid = vx_core_id()
382+
/// Replaces with: llvm.call @vx_printf(format, args...)
383+
/// vx_printf has the same signature as standard printf
381384
struct PrintfOpLowering : public OpRewritePattern<LLVM::CallOp> {
382385
using OpRewritePattern<LLVM::CallOp>::OpRewritePattern;
383386

@@ -401,17 +404,6 @@ struct PrintfOpLowering : public OpRewritePattern<LLVM::CallOp> {
401404
MLIRContext *context = gpuModule.getContext();
402405
auto i32Type = rewriter.getI32Type();
403406

404-
// Declare vx_core_id function in gpu.module if not already declared
405-
auto vxCoreIdFunc = gpuModule.lookupSymbol<LLVM::LLVMFuncOp>("vx_core_id");
406-
if (!vxCoreIdFunc) {
407-
OpBuilder::InsertionGuard guard(rewriter);
408-
rewriter.setInsertionPointToStart(gpuModule.getBody());
409-
410-
auto funcType = LLVM::LLVMFunctionType::get(i32Type, {}, /*isVarArg=*/false);
411-
vxCoreIdFunc = rewriter.create<LLVM::LLVMFuncOp>(
412-
gpuModule.getLoc(), "vx_core_id", funcType);
413-
}
414-
415407
// Declare vx_printf function in gpu.module if not already declared
416408
auto vxPrintfFunc = gpuModule.lookupSymbol<LLVM::LLVMFuncOp>("vx_printf");
417409
if (!vxPrintfFunc) {
@@ -424,17 +416,9 @@ struct PrintfOpLowering : public OpRewritePattern<LLVM::CallOp> {
424416
gpuModule.getLoc(), "vx_printf", funcType);
425417
}
426418

427-
// Call vx_core_id() to get core ID
428-
auto coreIdCall = rewriter.create<LLVM::CallOp>(loc, vxCoreIdFunc, ValueRange{});
429-
Value coreId = coreIdCall.getResult();
430-
431-
// Build new argument list: format, cid, original_args...
419+
// Build argument list: pass all original arguments unchanged
432420
SmallVector<Value> newArgs;
433-
newArgs.push_back(callOp.getOperand(0)); // format string (first arg)
434-
newArgs.push_back(coreId); // core ID (new second arg)
435-
436-
// Add remaining original arguments (skip format which is operand 0)
437-
for (unsigned i = 1; i < callOp.getNumOperands(); ++i) {
421+
for (unsigned i = 0; i < callOp.getNumOperands(); ++i) {
438422
newArgs.push_back(callOp.getOperand(i));
439423
}
440424

0 commit comments

Comments
 (0)