diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 5c9f9dfaac..fd6a9d9d95 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -1283,11 +1283,13 @@ cc_library( "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:NVVMTarget", + "@llvm-project//mlir:ROCDLTarget", "@llvm-project//mlir:LinalgTransformOps", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:ComplexToLLVM", "@llvm-project//mlir:IndexToLLVM", @@ -1308,6 +1310,7 @@ cc_library( "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:ROCDLDialect", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp index 05b17ca64d..d62984b8b0 100644 --- a/src/enzyme_ad/jax/RegistryUtils.cpp +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -44,6 +44,7 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" @@ -62,6 +63,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVM/ROCDL/Target.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" @@ -84,6 +86,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "src/enzyme_ad/jax/Dialect/Ops.h" #include "src/enzyme_ad/jax/Passes/Passes.h" @@ -196,6 +199,7 @@ void registerDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -237,6 +241,7 @@ void loadAllRegisteredDialects(mlir::MLIRContext &context) { context.loadDialect(); context.loadDialect(); context.loadDialect(); + context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); @@ -281,11 +286,13 @@ void registerInterfaces(mlir::DialectRegistry ®istry) { mlir::registerConvertMemRefToLLVMInterface(registry); mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry); + mlir::ROCDL::registerROCDLTargetInterfaceExternalModels(registry); mlir::registerBuiltinDialectTranslation(registry); mlir::registerGPUDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); mlir::registerNVVMDialectTranslation(registry); mlir::registerOpenMPDialectTranslation(registry); + mlir::registerROCDLDialectTranslation(registry); mlir::registerConvertOpenMPToLLVMInterface(registry); mlir::vector::registerConvertVectorToLLVMInterface(registry);