diff --git a/src/enzyme_ad/jax/Passes/ConvertParallelToGPU.cpp b/src/enzyme_ad/jax/Passes/ConvertParallelToGPU.cpp
index b17aa90110..5f90c7b125 100644
--- a/src/enzyme_ad/jax/Passes/ConvertParallelToGPU.cpp
+++ b/src/enzyme_ad/jax/Passes/ConvertParallelToGPU.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Dominance.h"
@@ -2510,6 +2511,13 @@ gdgo->erase();
             continue;
           if (s0.getValue() == "target-cpu")
             sm = s1.getValue();
+          if (backend == "rocm") {
+            if (sm.find("sm_") != std::string::npos) {
+              llvm::errs() << "Error: Found NVIDIA architecture while "
+                              "targeting ROCm.\n";
+              std::abort();
+            }
+          }
         }
       }
     }
@@ -2526,15 +2534,25 @@ gdgo->erase();
       return;
     auto gmod = cast<gpu::GPUModuleOp>(gfunc->getParentOp());
     if (!gmod.getTargetsAttr()) {
-      auto chip = sm;
-      if (chip.size() == 0)
-        chip = "sm_80";
-      auto features = feat;
-      if (features.size() == 0)
-        features = "+ptx73";
-      auto target = NVVM::NVVMTargetAttr::get(
-          gmod.getContext(), /*optLevel*/ 2,
-          /*triple*/ "nvptx64-nvidia-cuda", chip, features);
+      Attribute target;
+      if (backend == "rocm") {
+        auto chip = "gfx900";
+        auto features = "+wavefrontsize64";
+        target = ROCDL::ROCDLTargetAttr::get(
+            gmod.getContext(),
+            /*optLevel=*/3, /*triple=*/"amdgcn-amd-amdhsa", chip, features,
+            /*abiVersion=*/"600");
+      } else {
+        auto chip = sm;
+        if (chip.size() == 0)
+          chip = "sm_80";
+        auto features = feat;
+        if (features.size() == 0)
+          features = "+ptx73";
+        target = NVVM::NVVMTargetAttr::get(
+            gmod.getContext(), /*optLevel*/ 3,
+            /*triple*/ "nvptx64-nvidia-cuda", chip, features);
+      }
       gmod.setTargetsAttr(ArrayAttr::get(gmod.getContext(), target));
 
       DataLayoutSpecInterface dataLayout = {};
diff --git a/src/enzyme_ad/jax/Passes/GPULaunchRecognition.cpp b/src/enzyme_ad/jax/Passes/GPULaunchRecognition.cpp
index 075fd21370..ea3c91a571 100644
--- a/src/enzyme_ad/jax/Passes/GPULaunchRecognition.cpp
+++ b/src/enzyme_ad/jax/Passes/GPULaunchRecognition.cpp
@@ -6,6 +6,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/IRMapping.h"
 #include "src/enzyme_ad/jax/Dialect/Dialect.h"
@@ -39,7 +40,7 @@ struct GPULaunchRecognitionPass
       gpuModule = gpu::GPUModuleOp::create(
           moduleBuilder, getOperation()->getLoc(), gpuModuleName);
 
-      std::string sm;
+      std::string sm; // NVIDIA SM architecture string, e.g. "sm_80"
       if (auto attr = dyn_cast_or_null<ArrayAttr>(func.getPassthroughAttr())) {
         for (auto a : attr) {
           if (auto ar = dyn_cast<ArrayAttr>(a)) {
@@ -54,24 +55,37 @@ struct GPULaunchRecognitionPass
           }
         }
       }
+      std::string feat;
       if (auto attr = dyn_cast_or_null<LLVM::TargetFeaturesAttr>(
              func.getTargetFeaturesAttr())) {
        feat = attr.getFeaturesString();
      }
-      auto chip = sm;
-      if (chip.size() == 0)
-        chip = "sm_80";
-      auto features = feat;
-      if (features.size() == 0)
-        features = "+ptx73";
-
-      // TODO get these target attrs from somewhere
-      auto target = moduleBuilder.getAttr<NVVM::NVVMTargetAttr>(
-          /*optLevel=*/2, /*triple=*/"nvptx64-nvidia-cuda", chip, features,
-          /*flags=*/nullptr,
-          /*linkLibs=*/nullptr);
+      Attribute target;
+      if (backend == "rocm") {
+        // Chip is temporarily left empty ("") for the ROCm backend
+        auto chip = "";
+        auto features = "+wavefrontsize64";
+
+        target = moduleBuilder.getAttr<ROCDL::ROCDLTargetAttr>(
+            /*optLevel=*/3, /*triple=*/"amdgcn-amd-amdhsa", chip, features,
+            /*abiVersion=*/"",
+            /*flags=*/nullptr,
+            /*linkLibs=*/nullptr);
+      } else {
+        // Default to CUDA/NVVM
+        auto chip = sm;
+        if (chip.size() == 0)
+          chip = "sm_80";
+        auto features = feat;
+        if (features.size() == 0)
+          features = "+ptx73";
+        target = moduleBuilder.getAttr<NVVM::NVVMTargetAttr>(
+            /*optLevel=*/3, /*triple=*/"nvptx64-nvidia-cuda", chip, features,
+            /*flags=*/nullptr,
+            /*linkLibs=*/nullptr);
+      }
 
       gpuModule.setTargetsAttr(moduleBuilder.getArrayAttr({target}));
 
       DataLayoutSpecInterface dataLayout = {};
diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index 356aab7b46..ad5bb444f0 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -806,20 +806,29 @@ def CuDNNHLOOpt : Pass<"enzymexla-cudnn-hlo-opt"> {
 }
 
 def GPULaunchRecognition : Pass<"gpu-launch-recognition"> {
-  let summary = "Optimize stablehlo to emit cuDNN specific optimizations";
+  let summary = "Recognize and convert GPU kernel launches to GPU dialect operations";
   let dependentDialects = [
     "enzymexla::EnzymeXLADialect",
     "arith::ArithDialect",
     "gpu::GPUDialect",
     "mlir::NVVM::NVVMDialect",
+    "mlir::ROCDL::ROCDLDialect",
     "mlir::DLTIDialect"
   ];
-  let options = [Option<
+  let options = [
+    Option<
     /*C++ variable name=*/"use_launch_func",
     /*CLI argument=*/"use_launch_func",
     /*type=*/"bool",
     /*default=*/"false",
-    /*description=*/"Convert Periodic Concat to Manual Computation with CollectivePermute">];
+    /*description=*/"Use gpu.launch_func to launch the recognized kernels">,
+    Option<
+      /*C++ variable name=*/"backend",
+      /*CLI argument=*/"backend",
+      /*type=*/"std::string",
+      /*default=*/"\"cuda\"",
+      /*description=*/"Target backend (cuda or rocm)">,
+  ];
 }
 
 def MergeGPUModulesPass : Pass<"merge-gpu-modules", "mlir::ModuleOp"> {
@@ -837,13 +846,21 @@ def ConvertParallelToGPU1 : Pass<"convert-parallel-to-gpu1"> {
 
 def ConvertParallelToGPU2 : Pass<"convert-parallel-to-gpu2"> {
   let summary = "Convert parallel loops to gpu";
-  let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect", "memref::MemRefDialect", "gpu::GPUDialect", "mlir::NVVM::NVVMDialect"];
-  let options = [Option<
+  let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect", "memref::MemRefDialect", "gpu::GPUDialect", "mlir::NVVM::NVVMDialect", "mlir::ROCDL::ROCDLDialect"];
+  let options = [
+    Option<
     /*C++ variable name=*/"emitGPUKernelLaunchBounds",
     /*CLI argument=*/"emitGPUKernelLaunchBounds",
     /*type=*/"bool",
     /*default=*/"false",
-    /*description=*/"Convert Periodic Concat to Manual Computation with CollectivePermute">];
+    /*description=*/"Emit launch bounds on the generated GPU kernels">,
+    Option<
+      /*C++ variable name=*/"backend",
+      /*CLI argument=*/"backend",
+      /*type=*/"std::string",
+      /*default=*/"\"cuda\"",
+      /*description=*/"Target backend (cuda or rocm)">
+  ];
 }
 
 def PolygeistMem2Reg : Pass<"polygeist-mem2reg"> {
diff --git a/src/enzyme_ad/jax/raise.cpp b/src/enzyme_ad/jax/raise.cpp
index ee964b76e3..eca90b99dc 100644
--- a/src/enzyme_ad/jax/raise.cpp
+++ b/src/enzyme_ad/jax/raise.cpp
@@ -85,7 +85,11 @@ extern "C" std::string runLLVMToMLIRRoundTrip(std::string input,
   std::string pass_pipeline =
       "inline{default-pipeline=canonicalize "
      "max-iterations=4},sroa-wrappers{set_private=false attributor=false},gpu-launch-"
-      "recognition,canonicalize,libdevice-funcs-raise,canonicalize,symbol-dce,";
+      "recognition{backend=";
+  pass_pipeline += backend;
+  pass_pipeline += "}";
+  pass_pipeline += ","
"canonicalize,libdevice-funcs-raise,canonicalize,symbol-dce,"; if (backend == "cpu") pass_pipeline += "parallel-lower{wrapParallelOps=false},"; @@ -128,7 +132,11 @@ extern "C" std::string runLLVMToMLIRRoundTrip(std::string input, } pass_pipeline += "symbol-dce,enzyme,remove-unnecessary-enzyme-ops,lower-affine"; if (backend != "cpu") - pass_pipeline += ",convert-parallel-to-gpu1,gpu-kernel-outlining,canonicalize,convert-parallel-to-gpu2,lower-affine"; + pass_pipeline += ",convert-parallel-to-gpu1,gpu-kernel-outlining,canonicalize,convert-parallel-to-gpu2{backend="; + pass_pipeline += backend; + pass_pipeline += "}"; + pass_pipeline += "," + "lower-affine"; if (getenv("REACTANT_OMP")) { pass_pipeline += ",convert-scf-to-openmp,"; } else { diff --git a/test/lit_tests/lowering/gpu-recognize2.mlir b/test/lit_tests/lowering/gpu-recognize2.mlir index a626a2372f..739eccd633 100644 --- a/test/lit_tests/lowering/gpu-recognize2.mlir +++ b/test/lit_tests/lowering/gpu-recognize2.mlir @@ -1,4 +1,5 @@ // RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(gpu-launch-recognition)" | FileCheck %s +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(gpu-launch-recognition{backend=rocm})" | FileCheck %s --check-prefix=CHECK-ROCM #tbaa_root = #llvm.tbaa_root #tbaa_type_desc = #llvm.tbaa_type_desc}> @@ -68,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec = dense<32> : vec llvm.func local_unnamed_addr @__mlir_cuda_caller_phase3(...) attributes {sym_visibility = "private"} } -// CHECK: gpu.module @__mlir_gpu_module [#nvvm.target] attributes {dlti.dl_spec = #dlti.dl_spec} { +// CHECK: gpu.module @__mlir_gpu_module [#nvvm.target] attributes {dlti.dl_spec = #dlti.dl_spec} { // CHECK-NEXT: gpu.func @reactant$_Z18__device_stub__fooPi(%arg0: !llvm.ptr {llvm.nocapture, llvm.noundef, llvm.writeonly}) kernel { // CHECK-NEXT: %0 = nvvm.read.ptx.sreg.tid.x : i32 // CHECK-NEXT: %1 = llvm.zext nneg %0 : i32 to i64 @@ -144,3 +145,80 @@ module attributes {dlti.dl_spec = #dlti.dl_spec = dense<32> : vec // CHECK-NEXT: llvm.store %0, %2 {alignment = 4 : i64, tbaa = [#tbaa_tag]} : i32, !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } + +// CHECK-ROCM: gpu.module @__mlir_gpu_module [#rocdl.target] attributes {dlti.dl_spec = #dlti.dl_spec} { +// CHECK-ROCM-NEXT: gpu.func @reactant$_Z18__device_stub__fooPi(%arg0: !llvm.ptr {llvm.nocapture, llvm.noundef, llvm.writeonly}) kernel { +// CHECK-ROCM-NEXT: %0 = nvvm.read.ptx.sreg.tid.x : i32 +// CHECK-ROCM-NEXT: %1 = llvm.zext nneg %0 : i32 to i64 +// CHECK-ROCM-NEXT: %2 = llvm.getelementptr inbounds|nuw %arg0[%1] : (!llvm.ptr, i64) -> !llvm.ptr, i32 +// CHECK-ROCM-NEXT: llvm.store %0, %2 {alignment = 4 : i64, tbaa = [#tbaa_tag]} : i32, !llvm.ptr +// CHECK-ROCM-NEXT: gpu.return +// CHECK-ROCM-NEXT: } +// CHECK-ROCM-NEXT: } +// CHECK-ROCM-NEXT: llvm.comdat @__llvm_global_comdat { +// CHECK-ROCM-NEXT: } +// CHECK-ROCM-NEXT: llvm.module_flags [#llvm.mlir.module_flag, #llvm.mlir.module_flag, #llvm.mlir.module_flag, #llvm.mlir.module_flag, #llvm.mlir.module_flag, #llvm.mlir.module_flag] +// CHECK-ROCM-NEXT: llvm.mlir.global private unnamed_addr constant @".str"("res = %d\0A\00") {addr_space = 0 : i32, alignment = 1 : i64, dso_local, sym_visibility = "private"} +// CHECK-ROCM-NEXT: llvm.func local_unnamed_addr @main() -> (i32 {llvm.noundef}) attributes {dso_local, no_infs_fp_math = true, no_nans_fp_math = true, no_signed_zeros_fp_math = true, passthrough = ["mustprogress", "norecurse", ["approx-func-fp-math", "true"], ["min-legal-vector-width", "0"], 
["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "x86-64"]], target_cpu = "x86-64", target_features = #llvm.target_features<["+cmov", "+cx8", "+fxsr", "+mmx", "+sse", "+sse2", "+x87"]>, tune_cpu = "generic", unsafe_fp_math = true, uwtable_kind = #llvm.uwtableKind} { +// CHECK-ROCM-NEXT: %0 = llvm.mlir.constant(1 : i32) : i32 +// CHECK-ROCM-NEXT: %1 = llvm.mlir.constant(512 : i64) : i64 +// CHECK-ROCM-NEXT: %2 = llvm.mlir.constant(true) : i1 +// CHECK-ROCM-NEXT: %3 = "enzymexla.gpu_kernel_address"() <{fn = @__mlir_gpu_module::@reactant$_Z18__device_stub__fooPi}> : () -> !llvm.ptr +// CHECK-ROCM-NEXT: %4 = llvm.mlir.constant(128 : i32) : i32 +// CHECK-ROCM-NEXT: %5 = llvm.mlir.constant(0 : i64) : i64 +// CHECK-ROCM-NEXT: %6 = llvm.mlir.zero : !llvm.ptr +// CHECK-ROCM-NEXT: %7 = llvm.mlir.constant(0 : i32) : i32 +// CHECK-ROCM-NEXT: %8 = llvm.mlir.addressof @".str" : !llvm.ptr +// CHECK-ROCM-NEXT: %9 = llvm.mlir.constant(2 : i32) : i32 +// CHECK-ROCM-NEXT: %10 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr +// CHECK-ROCM-NEXT: %11 = llvm.alloca %0 x !llvm.array<128 x i32> {alignment = 16 : i64} : (i32) -> !llvm.ptr +// CHECK-ROCM-NEXT: %12 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr +// CHECK-ROCM-NEXT: llvm.intr.lifetime.start %10 : !llvm.ptr +// CHECK-ROCM-NEXT: %13 = arith.index_cast %1 : i64 to index +// CHECK-ROCM-NEXT: %memref = gpu.alloc (%13) : memref +// CHECK-ROCM-NEXT: %14 = "enzymexla.memref2pointer"(%memref) : (memref) -> !llvm.ptr +// CHECK-ROCM-NEXT: llvm.store %14, %10 : !llvm.ptr, !llvm.ptr +// CHECK-ROCM-NEXT: %15 = llvm.mlir.zero : i32 +// CHECK-ROCM-NEXT: llvm.intr.lifetime.start %11 : !llvm.ptr +// CHECK-ROCM-NEXT: llvm.cond_br %2, ^bb1, ^bb2 +// CHECK-ROCM-NEXT: ^bb1: // pred: ^bb0 +// CHECK-ROCM-NEXT: %16 = llvm.load %10 {alignment = 8 : i64, tbaa = [#tbaa_tag1]} : !llvm.ptr -> !llvm.ptr +// CHECK-ROCM-NEXT: %17 = llvm.trunc %5 : i64 to i32 +// CHECK-ROCM-NEXT: %18 = llvm.sext %0 : i32 to i64 +// CHECK-ROCM-NEXT: %19 = llvm.sext %0 : i32 to i64 +// CHECK-ROCM-NEXT: %20 = llvm.sext %0 : i32 to i64 +// CHECK-ROCM-NEXT: %21 = llvm.sext %4 : i32 to i64 +// CHECK-ROCM-NEXT: %22 = llvm.sext %0 : i32 to i64 +// CHECK-ROCM-NEXT: %23 = llvm.sext %0 : i32 to i64 +// CHECK-ROCM-NEXT: gpu.launch_func @__mlir_gpu_module::@reactant$_Z18__device_stub__fooPi blocks in (%18, %19, %20) threads in (%21, %22, %23) : i64 dynamic_shared_memory_size %17 args(%16 : !llvm.ptr) +// CHECK-ROCM-NEXT: llvm.br ^bb2 +// CHECK-ROCM-NEXT: ^bb2: // 2 preds: ^bb0, ^bb1 +// CHECK-ROCM-NEXT: llvm.intr.lifetime.start %12 : !llvm.ptr +// CHECK-ROCM-NEXT: llvm.store %7, %12 {alignment = 4 : i64, tbaa = [#tbaa_tag]} : i32, !llvm.ptr +// CHECK-ROCM-NEXT: %24 = "enzymexla.gpu_occupancy"(%0, %5, %7) <{fn = @__mlir_gpu_module::@reactant$_Z18__device_stub__fooPi}> : (i32, i64, i32) -> i32 +// CHECK-ROCM-NEXT: llvm.store %24, %12 : i32, !llvm.ptr +// CHECK-ROCM-NEXT: %25 = llvm.mlir.zero : i32 +// CHECK-ROCM-NEXT: %26 = llvm.alloca %0 x !llvm.struct<"struct.cudaFuncAttributes", (i64, i64, i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, array<16 x i32>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr +// CHECK-ROCM-NEXT: %27 = llvm.call @cudaFuncGetAttributes(%26, %3) {no_unwind} : (!llvm.ptr {llvm.nonnull, llvm.noundef}, !llvm.ptr {llvm.nonnull, llvm.noundef}) -> (i32 {llvm.noundef}) +// CHECK-ROCM-NEXT: %28 = llvm.load %12 {alignment = 4 : i64, tbaa = [#tbaa_tag]} : !llvm.ptr -> i32 +// CHECK-ROCM-NEXT: %29 = 
llvm.call @printf(%8, %28) vararg(!llvm.func) : (!llvm.ptr {llvm.dereferenceable = 1 : i64, llvm.nonnull, llvm.noundef}, i32 {llvm.noundef}) -> i32 +// CHECK-ROCM-NEXT: %30 = llvm.load %10 {alignment = 8 : i64, tbaa = [#tbaa_tag1]} : !llvm.ptr -> !llvm.ptr +// CHECK-ROCM-NEXT: %31 = "enzymexla.pointer2memref"(%11) : (!llvm.ptr) -> memref +// CHECK-ROCM-NEXT: %32 = "enzymexla.pointer2memref"(%30) : (!llvm.ptr) -> memref +// CHECK-ROCM-NEXT: %33 = arith.index_cast %1 : i64 to index +// CHECK-ROCM-NEXT: enzymexla.memcpy %31, %32, %33 : memref, memref +// CHECK-ROCM-NEXT: %34 = llvm.mlir.zero : i32 +// CHECK-ROCM-NEXT: %35 = llvm.load %10 {alignment = 8 : i64, tbaa = [#tbaa_tag1]} : !llvm.ptr -> !llvm.ptr +// CHECK-ROCM-NEXT: %36 = llvm.call @cudaFree(%35) : (!llvm.ptr {llvm.noundef}) -> i32 +// CHECK-ROCM-NEXT: llvm.intr.lifetime.end %12 : !llvm.ptr +// CHECK-ROCM-NEXT: llvm.intr.lifetime.end %11 : !llvm.ptr +// CHECK-ROCM-NEXT: llvm.intr.lifetime.end %10 : !llvm.ptr +// CHECK-ROCM-NEXT: llvm.return %7 : i32 +// CHECK-ROCM-NEXT: } +// CHECK-ROCM: llvm.func internal @reactant$_Z18__device_stub__fooPi(%arg0: !llvm.ptr {llvm.nocapture, llvm.noundef, llvm.writeonly}) attributes {dso_local, frame_pointer = #llvm.framePointerKind, no_infs_fp_math = true, no_inline, no_nans_fp_math = true, no_signed_zeros_fp_math = true, no_unwind, passthrough = ["mustprogress", "nofree", "norecurse", "nosync", ["approx-func-fp-math", "true"], ["no-trapping-math", "true"], ["stack-protector-buffer-size", "8"], ["target-cpu", "sm_120"], ["uniform-work-group-size", "true"]], sym_visibility = "private", target_cpu = "sm_120", target_features = #llvm.target_features<["+ptx88", "+sm_120"]>, unsafe_fp_math = true, will_return} { +// CHECK-ROCM-NEXT: %0 = nvvm.read.ptx.sreg.tid.x : i32 +// CHECK-ROCM-NEXT: %1 = llvm.zext nneg %0 : i32 to i64 +// CHECK-ROCM-NEXT: %2 = llvm.getelementptr inbounds|nuw %arg0[%1] : (!llvm.ptr, i64) -> !llvm.ptr, i32 +// CHECK-ROCM-NEXT: llvm.store %0, %2 {alignment = 4 : i64, tbaa = [#tbaa_tag]} : i32, !llvm.ptr +// CHECK-ROCM-NEXT: llvm.return +// CHECK-ROCM-NEXT: } \ No newline at end of file diff --git a/test/lit_tests/lowering/gpu-recognize3.mlir b/test/lit_tests/lowering/gpu-recognize3.mlir index af889c7b5a..c49e8c0e8c 100644 --- a/test/lit_tests/lowering/gpu-recognize3.mlir +++ b/test/lit_tests/lowering/gpu-recognize3.mlir @@ -114,4 +114,4 @@ module attributes {dlti.dl_spec = #dlti.dl_spec = dense<32> : vec // CHECK-NEXT: llvm.intr.lifetime.end %11 : !llvm.ptr // CHECK-NEXT: llvm.intr.lifetime.end %10 : !llvm.ptr // CHECK-NEXT: llvm.return %7 : i32 -// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file