@@ -317,8 +317,11 @@ struct BarrierOpLowering : public ConvertOpToLLVMPattern<gpu::BarrierOp> {
317317 matchAndRewrite (gpu::BarrierOp op, OpAdaptor adaptor,
318318 ConversionPatternRewriter &rewriter) const override {
319319 Location loc = op.getLoc ();
320- auto module = op->getParentOfType <ModuleOp>();
321- MLIRContext *context = module .getContext ();
320+ // Declare functions in gpu.module (not top-level module) so they're visible
321+ auto gpuModule = op->getParentOfType <gpu::GPUModuleOp>();
322+ if (!gpuModule)
323+ return failure ();
324+ MLIRContext *context = gpuModule.getContext ();
322325
323326 // Allocate barrier ID (simple counter for now)
324327 // TODO: Proper barrier ID allocation to avoid conflicts
@@ -330,36 +333,36 @@ struct BarrierOpLowering : public ConvertOpToLLVMPattern<gpu::BarrierOp> {
330333 auto barIdConstant = rewriter.create <LLVM::ConstantOp>(
331334 loc, i32Type, rewriter.getI32IntegerAttr (barrierId));
332335
333- // Declare vx_num_warps function to get warp count
334- auto vxNumWarpsFunc = module .lookupSymbol <LLVM::LLVMFuncOp>(" vx_num_warps" );
336+ // Declare vx_num_warps function in gpu.module if not already declared
337+ auto vxNumWarpsFunc = gpuModule .lookupSymbol <LLVM::LLVMFuncOp>(" vx_num_warps" );
335338 if (!vxNumWarpsFunc) {
336339 OpBuilder::InsertionGuard guard (rewriter);
337- rewriter.setInsertionPointToStart (module .getBody ());
340+ rewriter.setInsertionPointToStart (gpuModule .getBody ());
338341
339342 auto funcType = LLVM::LLVMFunctionType::get (
340343 i32Type, {}, /* isVarArg=*/ false );
341344
342345 vxNumWarpsFunc = rewriter.create <LLVM::LLVMFuncOp>(
343- module .getLoc (), " vx_num_warps" , funcType);
346+ gpuModule .getLoc (), " vx_num_warps" , funcType);
344347 }
345348
346349 // Call vx_num_warps() to get number of warps
347350 auto numWarps = rewriter.create <LLVM::CallOp>(
348351 loc, vxNumWarpsFunc, ValueRange{});
349352
350- // Declare vx_barrier function if not already declared
351- auto vxBarrierFunc = module .lookupSymbol <LLVM::LLVMFuncOp>(" vx_barrier" );
353+ // Declare vx_barrier function in gpu.module if not already declared
354+ auto vxBarrierFunc = gpuModule .lookupSymbol <LLVM::LLVMFuncOp>(" vx_barrier" );
352355 if (!vxBarrierFunc) {
353356 OpBuilder::InsertionGuard guard (rewriter);
354- rewriter.setInsertionPointToStart (module .getBody ());
357+ rewriter.setInsertionPointToStart (gpuModule .getBody ());
355358
356359 auto funcType = LLVM::LLVMFunctionType::get (
357360 LLVM::LLVMVoidType::get (context),
358361 {i32Type, i32Type},
359362 /* isVarArg=*/ false );
360363
361364 vxBarrierFunc = rewriter.create <LLVM::LLVMFuncOp>(
362- module .getLoc (), " vx_barrier" , funcType);
365+ gpuModule .getLoc (), " vx_barrier" , funcType);
363366 }
364367
365368 // Call vx_barrier(barrier_id, num_warps)
@@ -374,10 +377,10 @@ struct BarrierOpLowering : public ConvertOpToLLVMPattern<gpu::BarrierOp> {
374377 }
375378};
376379
377- // / Lower printf calls to vx_printf with core ID as first argument
380+ // / Lower printf calls to vx_printf
378381// / Matches: llvm.call @printf(format, args...)
379- // / Replaces with: llvm.call @vx_printf(format, cid, args...)
380- // / where cid = vx_core_id()
382+ // / Replaces with: llvm.call @vx_printf(format, args...)
383+ // / vx_printf has the same signature as standard printf
381384struct PrintfOpLowering : public OpRewritePattern <LLVM::CallOp> {
382385 using OpRewritePattern<LLVM::CallOp>::OpRewritePattern;
383386
@@ -401,17 +404,6 @@ struct PrintfOpLowering : public OpRewritePattern<LLVM::CallOp> {
401404 MLIRContext *context = gpuModule.getContext ();
402405 auto i32Type = rewriter.getI32Type ();
403406
404- // Declare vx_core_id function in gpu.module if not already declared
405- auto vxCoreIdFunc = gpuModule.lookupSymbol <LLVM::LLVMFuncOp>(" vx_core_id" );
406- if (!vxCoreIdFunc) {
407- OpBuilder::InsertionGuard guard (rewriter);
408- rewriter.setInsertionPointToStart (gpuModule.getBody ());
409-
410- auto funcType = LLVM::LLVMFunctionType::get (i32Type, {}, /* isVarArg=*/ false );
411- vxCoreIdFunc = rewriter.create <LLVM::LLVMFuncOp>(
412- gpuModule.getLoc (), " vx_core_id" , funcType);
413- }
414-
415407 // Declare vx_printf function in gpu.module if not already declared
416408 auto vxPrintfFunc = gpuModule.lookupSymbol <LLVM::LLVMFuncOp>(" vx_printf" );
417409 if (!vxPrintfFunc) {
@@ -424,17 +416,9 @@ struct PrintfOpLowering : public OpRewritePattern<LLVM::CallOp> {
424416 gpuModule.getLoc (), " vx_printf" , funcType);
425417 }
426418
427- // Call vx_core_id() to get core ID
428- auto coreIdCall = rewriter.create <LLVM::CallOp>(loc, vxCoreIdFunc, ValueRange{});
429- Value coreId = coreIdCall.getResult ();
430-
431- // Build new argument list: format, cid, original_args...
419+ // Build argument list: pass all original arguments unchanged
432420 SmallVector<Value> newArgs;
433- newArgs.push_back (callOp.getOperand (0 )); // format string (first arg)
434- newArgs.push_back (coreId); // core ID (new second arg)
435-
436- // Add remaining original arguments (skip format which is operand 0)
437- for (unsigned i = 1 ; i < callOp.getNumOperands (); ++i) {
421+ for (unsigned i = 0 ; i < callOp.getNumOperands (); ++i) {
438422 newArgs.push_back (callOp.getOperand (i));
439423 }
440424
0 commit comments