Skip to content

Commit 8a820f1

Browse files
authored
[mlir][LLVM|ptr] Add the #llvm.address_space attribute, and allow ptr translation (#156333)
This commit introduces the `#llvm.address_space` attribute. This attribute implements the `ptr::MemorySpaceAttrInterface`, establishing the semantics of the LLVM address space. This allows making `!ptr.ptr` translatable to LLVM IR as long it uses the `#llvm.address_space` attribute. Concretely, `!ptr.ptr<#llvm.address_space<N>>` now translates to `ptr addrspace(N)`. Additionally, this patch makes `PtrLikeTypes` with no metadata, no element type, and with `#llvm.address_space` memory space, compatible with the LLVM dialect. **Infrastructure Updates:** - Refactor `ptr::MemorySpaceAttrInterface` to include DataLayout parameter for better validation - Add new utility functions `LLVM::isLoadableType()` and `LLVM::isTypeCompatibleWithAtomicOp()` - Update type compatibility checks to support ptr-like types with LLVM address spaces - Splice the `MemorySpaceAttrInterface` to its own library, so the LLVMDialect won't depend on the PtrDialect yet **Translation Support:** - New `PtrToLLVMIRTranslation` module for converting ptr dialect to LLVM IR - Type translation support for ptr types with LLVM address spaces - Proper address space preservation during IR lowering Example: ```mlir llvm.func @llvm_ops_with_ptr_values(%arg0: !llvm.ptr) { %1 = llvm.load %arg0 : !llvm.ptr -> !ptr.ptr<#llvm.address_space<1>> llvm.store %1, %arg0 : !ptr.ptr<#llvm.address_space<1>>, !llvm.ptr llvm.return } ``` Translates to: ```llvmir ; ModuleID = 'LLVMDialectModule' source_filename = "LLVMDialectModule" define void @llvm_ops_with_ptr_values(ptr %0) { %2 = load ptr addrspace(1), ptr %0, align 8 store ptr addrspace(1) %2, ptr %0, align 8 ret void } !llvm.module.flags = !{!0} !0 = !{i32 2, !"Debug Info Version", i32 3} ```
1 parent 8e4bda1 commit 8a820f1

29 files changed

+420
-22
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
include "mlir/Dialect/LLVMIR/LLVMDialect.td"
1313
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
14+
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
1415
include "mlir/IR/AttrTypeBase.td"
1516
include "mlir/IR/CommonAttrConstraints.td"
1617
include "mlir/Interfaces/DataLayoutInterfaces.td"
@@ -23,6 +24,41 @@ class LLVM_Attr<string name, string attrMnemonic,
2324
let mnemonic = attrMnemonic;
2425
}
2526

27+
//===----------------------------------------------------------------------===//
28+
// AddressSpaceAttr
29+
//===----------------------------------------------------------------------===//
30+
31+
def LLVM_AddressSpaceAttr :
32+
LLVM_Attr<"AddressSpace", "address_space", [
33+
DeclareAttrInterfaceMethods<MemorySpaceAttrInterface>
34+
]> {
35+
let summary = "LLVM address space";
36+
let description = [{
37+
The `address_space` attribute represents an LLVM address space. It takes an
38+
unsigned integer parameter that specifies the address space number.
39+
40+
Different address spaces in LLVM can have different properties:
41+
- Address space 0 is the default/generic address space
42+
- Other address spaces may have specific semantics (e.g., shared memory,
43+
constant memory, etc.) depending on the target architecture
44+
45+
Example:
46+
47+
```mlir
48+
// Address space 0 (default)
49+
#llvm.address_space<0>
50+
51+
// Address space 1 (e.g., global memory on some targets)
52+
#llvm.address_space<1>
53+
54+
// Address space 3 (e.g., shared memory on some GPU targets)
55+
#llvm.address_space<3>
56+
```
57+
}];
58+
let parameters = (ins "unsigned":$addressSpace);
59+
let assemblyFormat = "`<` $addressSpace `>`";
60+
}
61+
2662
//===----------------------------------------------------------------------===//
2763
// CConvAttr
2864
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_DIALECT_LLVMIR_LLVMATTRS_H_
1616

1717
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18+
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
1819
#include "mlir/IR/OpImplementation.h"
1920
#include "mlir/Interfaces/DataLayoutInterfaces.h"
2021
#include <optional>

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace mlir {
2828

2929
class AsmParser;
3030
class AsmPrinter;
31+
class DataLayout;
3132

3233
namespace LLVM {
3334
class LLVMDialect;
@@ -111,6 +112,15 @@ bool isCompatibleFloatingPointType(Type type);
111112
/// dialect pointers and LLVM dialect scalable vector types.
112113
bool isCompatibleVectorType(Type type);
113114

115+
/// Returns `true` if the given type is a loadable type compatible with the LLVM
116+
/// dialect.
117+
bool isLoadableType(Type type);
118+
119+
/// Returns true if the given type is supported by atomic operations. All
120+
/// integer, float, and pointer types with a power-of-two bitsize and a minimal
121+
/// size of 8 bits are supported.
122+
bool isTypeCompatibleWithAtomicOp(Type type, const DataLayout &dataLayout);
123+
114124
/// Returns the element count of any LLVM-compatible vector type.
115125
llvm::ElementCount getVectorNumElements(Type type);
116126

mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ mlir_tablegen(PtrOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ptr)
77
add_mlir_dialect_tablegen_target(MLIRPtrOpsAttributesIncGen)
88

99
set(LLVM_TARGET_DEFINITIONS MemorySpaceInterfaces.td)
10-
mlir_tablegen(MemorySpaceInterfaces.h.inc -gen-op-interface-decls)
11-
mlir_tablegen(MemorySpaceInterfaces.cpp.inc -gen-op-interface-defs)
1210
mlir_tablegen(MemorySpaceAttrInterfaces.h.inc -gen-attr-interface-decls)
1311
mlir_tablegen(MemorySpaceAttrInterfaces.cpp.inc -gen-attr-interface-defs)
1412
add_mlir_dialect_tablegen_target(MLIRPtrMemorySpaceInterfacesIncGen)

mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717
#include "mlir/IR/BuiltinAttributes.h"
1818
#include "mlir/IR/OpDefinition.h"
1919

20+
#include <functional>
21+
#include <optional>
22+
2023
namespace mlir {
2124
class Operation;
25+
class DataLayout;
2226
namespace ptr {
2327
enum class AtomicBinOp : uint32_t;
2428
enum class AtomicOrdering : uint32_t;
@@ -27,6 +31,4 @@ enum class AtomicOrdering : uint32_t;
2731

2832
#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.h.inc"
2933

30-
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h.inc"
31-
3234
#endif // MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H

mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
4343
/*args=*/ (ins "::mlir::Type":$type,
4444
"::mlir::ptr::AtomicOrdering":$ordering,
4545
"std::optional<int64_t>":$alignment,
46+
"const ::mlir::DataLayout *":$dataLayout,
4647
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
4748
>,
4849
InterfaceMethod<
@@ -58,6 +59,7 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
5859
/*args=*/ (ins "::mlir::Type":$type,
5960
"::mlir::ptr::AtomicOrdering":$ordering,
6061
"std::optional<int64_t>":$alignment,
62+
"const ::mlir::DataLayout *":$dataLayout,
6163
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
6264
>,
6365
InterfaceMethod<
@@ -74,6 +76,7 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
7476
"::mlir::Type":$type,
7577
"::mlir::ptr::AtomicOrdering":$ordering,
7678
"std::optional<int64_t>":$alignment,
79+
"const ::mlir::DataLayout *":$dataLayout,
7780
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
7881
>,
7982
InterfaceMethod<
@@ -91,6 +94,7 @@ def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> {
9194
"::mlir::ptr::AtomicOrdering":$successOrdering,
9295
"::mlir::ptr::AtomicOrdering":$failureOrdering,
9396
"std::optional<int64_t>":$alignment,
97+
"const ::mlir::DataLayout *":$dataLayout,
9498
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
9599
>,
96100
InterfaceMethod<

mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
#include "llvm/Support/TypeSize.h"
2020

2121
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
22+
#include "mlir/Dialect/Ptr/IR/PtrEnums.h"
2223

2324
#define GET_ATTRDEF_CLASSES
2425
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"
2526

26-
#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc"
27-
2827
#endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- PtrEnums.h - `ptr` dialect enums -------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the `ptr` dialect enums.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_PTR_IR_PTRENUMS_H
14+
#define MLIR_DIALECT_PTR_IR_PTRENUMS_H
15+
16+
#include "mlir/IR/BuiltinAttributeInterfaces.h"
17+
#include "mlir/IR/OpImplementation.h"
18+
19+
#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc"
20+
21+
#endif // MLIR_DIALECT_PTR_IR_PTRENUMS_H

mlir/include/mlir/Target/LLVMIR/Dialect/All.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
2626
#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
2727
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
28+
#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h"
2829
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
2930
#include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
3031
#include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h"
@@ -45,6 +46,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
4546
registerNVVMDialectTranslation(registry);
4647
registerOpenACCDialectTranslation(registry);
4748
registerOpenMPDialectTranslation(registry);
49+
registerPtrDialectTranslation(registry);
4850
registerROCDLDialectTranslation(registry);
4951
registerSPIRVDialectTranslation(registry);
5052
registerVCIXDialectTranslation(registry);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- PtrToLLVMIRTranslation.h - `ptr` to LLVM IR --------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This provides registration calls for `ptr` dialect to LLVM IR translation.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H
14+
#define MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H
15+
16+
namespace mlir {
17+
18+
class DialectRegistry;
19+
class MLIRContext;
20+
21+
/// Register the `ptr` dialect and the translation from it to the LLVM IR in the
22+
/// given registry;
23+
void registerPtrDialectTranslation(DialectRegistry &registry);
24+
25+
/// Register the `ptr` dialect and the translation from it in the registry
26+
/// associated with the given context.
27+
void registerPtrDialectTranslation(MLIRContext &context);
28+
29+
} // namespace mlir
30+
31+
#endif // MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H

0 commit comments

Comments
 (0)