
skatrak (Member) commented Oct 3, 2025

This patch introduces the `omp.alloc_shared_mem` and `omp.free_shared_mem`
operations to represent explicit allocations and deallocations of shared memory
across threads in a team, mirroring the existing `omp.target_allocmem` and
`omp.target_freemem`.

The `omp.alloc_shared_mem` op goes through the same Flang-specific
transformations as `omp.target_allocmem`, so that the size of the buffer can be
properly calculated when translating to LLVM IR.
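
As a rough illustration of what "properly calculated" involves, the sketch below restates in plain integers the chain of multiplications that `AllocMemOpConversion` in this patch emits as LLVM dialect ops. The helper `computeAllocBytes` and its parameter names are hypothetical; the inputs are assumed to stand for the results of `fir::computeElementDistance` and `fir::genAllocationScaleSize` (treated as 1 when absent) and for the `typeparams` and `shape` operands. The final cast to the index bit width is omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of Flang): mirrors the order of the
// multiplications emitted by the conversion pattern, using plain integers.
//   elementDistance: byte stride of one element (assumed to come from
//                    fir::computeElementDistance)
//   scaleSize:       extra factor from fir::genAllocationScaleSize, or 1
//   typeparams:      length type parameters (assumed, e.g. a character length)
//   shape:           dynamic extents of the allocated array
std::int64_t computeAllocBytes(std::int64_t elementDistance,
                               std::int64_t scaleSize,
                               const std::vector<std::int64_t> &typeparams,
                               const std::vector<std::int64_t> &shape) {
  std::int64_t size = elementDistance * scaleSize;
  for (std::int64_t param : typeparams)
    size *= param; // one multiply per type parameter
  for (std::int64_t extent : shape)
    size *= extent; // one multiply per dynamic extent
  return size;
}
```

For instance, 4-byte elements with two dynamic extents 3 and 4 give `computeAllocBytes(4, 1, {}, {3, 4})`, which is 48 bytes.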

The corresponding runtime functions produced for these new operations are
`__kmpc_alloc_shared` and `__kmpc_free_shared`, which previously could only be
created for implicit allocations (e.g. privatized and reduction variables).
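
Unlike a plain `free`, `__kmpc_free_shared` receives the allocation size as well as the pointer, which is why the translation of `omp.free_shared_mem` recomputes the size from its defining `omp.alloc_shared_mem`. A minimal sketch of one reason a free routine might need the size, assuming a stack-like (LIFO) allocator; `SharedStackModel` is a hypothetical toy and is not the OpenMP device runtime's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical toy model, NOT the OpenMP device runtime: a fixed buffer
// shared by a team and managed as a LIFO stack. It only illustrates why a
// free routine that is handed the size must get the same size that was
// used for the matching allocation.
class SharedStackModel {
public:
  explicit SharedStackModel(std::size_t capacity)
      : buffer_(capacity), top_(0) {}

  // Stand-in for __kmpc_alloc_shared(size): bump the top of the stack.
  void *allocShared(std::uint64_t size) {
    if (size > buffer_.size() - top_)
      return nullptr; // out of space in this toy model
    void *ptr = buffer_.data() + top_;
    top_ += size;
    return ptr;
  }

  // Stand-in for __kmpc_free_shared(ptr, size): pop exactly `size` bytes,
  // but only if (ptr, size) describes the most recent live allocation.
  bool freeShared(void *ptr, std::uint64_t size) {
    if (size > top_ || buffer_.data() + (top_ - size) != ptr)
      return false; // mismatched pointer or size
    top_ -= size;
    return true;
  }

  std::size_t used() const { return top_; }

private:
  std::vector<unsigned char> buffer_;
  std::size_t top_;
};
```

In the patch itself no size is stored at run time: `FreeSharedMemOp::verify` requires `heapref` to be defined by an `omp.alloc_shared_mem`, so the size can always be rebuilt from that op's type and operands.
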
llvmbot (Member) commented Oct 3, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-flang-codegen

Author: Sergio Afonso (skatrak)

Changes


Patch is 21.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/161862.diff

8 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp (+28-15)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+23)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+21-8)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+62)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+22)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+55-13)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+28)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+29-2)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 381b2a29c517a..c1a6b06d6a52b 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -222,35 +222,47 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter,
   return converter.convertType(firType);
 }
 
-// FIR Op specific conversion for TargetAllocMemOp
-struct TargetAllocMemOpConversion
-    : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
-  using OpenMPFIROpConversion::OpenMPFIROpConversion;
+// FIR Op specific conversion for allocation operations
+template <typename T>
+struct AllocMemOpConversion : public OpenMPFIROpConversion<T> {
+  using OpenMPFIROpConversion<T>::OpenMPFIROpConversion;
 
   llvm::LogicalResult
-  matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
+  matchAndRewrite(T allocmemOp,
+                  typename OpenMPFIROpConversion<T>::OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Type heapTy = allocmemOp.getAllocatedType();
     mlir::Location loc = allocmemOp.getLoc();
-    auto ity = lowerTy().indexType();
+    auto ity = OpenMPFIROpConversion<T>::lowerTy().indexType();
     mlir::Type dataTy = fir::unwrapRefType(heapTy);
-    mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy);
+    mlir::Type llvmObjectTy =
+        convertObjectType(OpenMPFIROpConversion<T>::lowerTy(), dataTy);
     if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
-      TODO(loc, "omp.target_allocmem codegen of derived type with length "
-                "parameters");
+      TODO(loc, allocmemOp->getName().getStringRef() +
+                    " codegen of derived type with length parameters");
     mlir::Value size = fir::computeElementDistance(
-        loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
+        loc, llvmObjectTy, ity, rewriter,
+        OpenMPFIROpConversion<T>::lowerTy().getDataLayout());
     if (auto scaleSize = fir::genAllocationScaleSize(
             loc, allocmemOp.getInType(), ity, rewriter))
       size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
-    for (mlir::Value opnd : adaptor.getOperands().drop_front())
+    for (mlir::Value opnd : adaptor.getTypeparams())
+      size = rewriter.create<mlir::LLVM::MulOp>(
+          loc, ity, size,
+          integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
+                      opnd));
+    for (mlir::Value opnd : adaptor.getShape())
       size = rewriter.create<mlir::LLVM::MulOp>(
-          loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
-    auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
+          loc, ity, size,
+          integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter, ity,
+                      opnd));
+    auto mallocTyWidth =
+        OpenMPFIROpConversion<T>::lowerTy().getIndexTypeBitwidth();
     auto mallocTy =
         mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
     if (mallocTyWidth != ity.getIntOrFloatBitWidth())
-      size = integerCast(lowerTy(), loc, rewriter, mallocTy, size);
+      size = integerCast(OpenMPFIROpConversion<T>::lowerTy(), loc, rewriter,
+                         mallocTy, size);
     rewriter.modifyOpInPlace(allocmemOp, [&]() {
       allocmemOp.setInType(rewriter.getI8Type());
       allocmemOp.getTypeparamsMutable().clear();
@@ -265,5 +277,6 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
   patterns.add<MapInfoOpConversion>(converter);
   patterns.add<PrivateClauseOpConversion>(converter);
-  patterns.add<TargetAllocMemOpConversion>(converter);
+  patterns.add<AllocMemOpConversion<mlir::omp::TargetAllocMemOp>,
+               AllocMemOpConversion<mlir::omp::AllocSharedMemOp>>(converter);
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 02d61c1a3626a..d8e5f8cf5a45e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2950,6 +2950,17 @@ class OpenMPIRBuilder {
   LLVM_ABI CallInst *createOMPFree(const LocationDescription &Loc, Value *Addr,
                                    Value *Allocator, std::string Name = "");
 
+  /// Create a runtime call for kmpc_alloc_shared.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param Size Size of allocated memory space.
+  /// \param Name Name of call Instruction.
+  ///
+  /// \returns CallInst to the kmpc_alloc_shared call.
+  LLVM_ABI CallInst *createOMPAllocShared(const LocationDescription &Loc,
+                                          Value *Size,
+                                          const Twine &Name = Twine(""));
+
   /// Create a runtime call for kmpc_alloc_shared.
   ///
   /// \param Loc The insert and source location description.
@@ -2961,6 +2972,18 @@ class OpenMPIRBuilder {
                                           Type *VarType,
                                           const Twine &Name = Twine(""));
 
+  /// Create a runtime call for kmpc_free_shared.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param Addr Value obtained from the corresponding kmpc_alloc_shared call.
+  /// \param Size Size of allocated memory space.
+  /// \param Name Name of call Instruction.
+  ///
+  /// \returns CallInst to the kmpc_free_shared call.
+  LLVM_ABI CallInst *createOMPFreeShared(const LocationDescription &Loc,
+                                         Value *Addr, Value *Size,
+                                         const Twine &Name = Twine(""));
+
   /// Create a runtime call for kmpc_free_shared.
   ///
   /// \param Loc The insert and source location description.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index bd483aa2c5e02..a18db939b5876 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6855,32 +6855,45 @@ CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
 }
 
 CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
-                                                Type *VarType,
+                                                Value *Size,
                                                 const Twine &Name) {
   IRBuilder<>::InsertPointGuard IPG(Builder);
   updateToLocation(Loc);
 
-  const DataLayout &DL = M.getDataLayout();
-  Value *Args[] = {Builder.getInt64(DL.getTypeStoreSize(VarType))};
+  Value *Args[] = {Size};
   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc_shared);
   CallInst *Call = Builder.CreateCall(Fn, Args, Name);
-  Call->addRetAttr(
-      Attribute::getWithAlignment(M.getContext(), DL.getPrefTypeAlign(Int64)));
+  Call->addRetAttr(Attribute::getWithAlignment(
+      M.getContext(), M.getDataLayout().getPrefTypeAlign(Int64)));
   return Call;
 }
 
+CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc,
+                                                Type *VarType,
+                                                const Twine &Name) {
+  return createOMPAllocShared(
+      Loc, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType)), Name);
+}
+
 CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
-                                               Value *Addr, Type *VarType,
+                                               Value *Addr, Value *Size,
                                                const Twine &Name) {
   IRBuilder<>::InsertPointGuard IPG(Builder);
   updateToLocation(Loc);
 
-  Value *Args[] = {
-      Addr, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType))};
+  Value *Args[] = {Addr, Size};
   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free_shared);
   return Builder.CreateCall(Fn, Args, Name);
 }
 
+CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc,
+                                               Value *Addr, Type *VarType,
+                                               const Twine &Name) {
+  return createOMPFreeShared(
+      Loc, Addr, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType)),
+      Name);
+}
+
 CallInst *OpenMPIRBuilder::createOMPInteropInit(
     const LocationDescription &Loc, Value *InteropVar,
     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8b206f58c7733..fa037c2ff9496 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -2202,6 +2202,68 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem",
   Arg<I64, "", [MemFree]>:$heapref
   );
   let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// AllocSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [
+    AttrSizedOperandSegments
+  ], clauses = [
+    OpenMP_HeapAllocClause
+  ]> {
+  let summary = "allocate storage on shared memory for an object of a given type";
+
+  let description = [{
+    Allocates memory shared across threads of a team for an object of the given
+    type. Returns a pointer representing the allocated memory. The memory is
+    uninitialized after allocation. Operations must be paired with
+    `omp.free_shared_mem` to avoid memory leaks.
+
+    ```mlir
+      // Allocate a static 3x3 integer vector.
+      %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+      // ...
+      omp.free_shared_mem %ptr_shared : !llvm.ptr
+    ```
+  }] # clausesDescription;
+
+  let results = (outs OpenMP_PointerLikeType);
+  let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)";
+}
+
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> {
+  let summary = "free shared memory";
+
+  let description = [{
+    Deallocates shared memory that was previously allocated by an
+    `omp.alloc_shared_mem` operation. After this operation, the deallocated
+    memory is in an undefined state and should not be accessed.
+    It is crucial to ensure that all accesses to the memory region are completed
+    before `omp.alloc_shared_mem` is called to avoid undefined behavior.
+
+    ```mlir
+      // Example of allocating and freeing shared memory.
+      %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr
+      // ...
+      omp.free_shared_mem %ptr_shared : !llvm.ptr
+    ```
+
+    The `heapref` operand represents the pointer to shared memory to be
+    deallocated, previously returned by `omp.alloc_shared_mem`.
+  }];
+
+  let arguments = (ins
+    Arg<OpenMP_PointerLikeType, "", [MemFree]>:$heapref
+  );
+  let assemblyFormat = "$heapref attr-dict `:` type($heapref)";
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index fabb1b8c173a2..3b48dce4b7989 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4161,6 +4161,28 @@ LogicalResult AllocateDirOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TargetFreeMemOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TargetFreeMemOp::verify() {
+  return getHeapref().getDefiningOp<TargetAllocMemOp>()
+             ? success()
+             : emitOpError() << "'heapref' operand must be defined by an "
+                                "'omp.target_allocmem' op";
+}
+
+//===----------------------------------------------------------------------===//
+// FreeSharedMemOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult FreeSharedMemOp::verify() {
+  return getHeapref().getDefiningOp<AllocSharedMemOp>()
+             ? success()
+             : emitOpError() << "'heapref' operand must be defined by an "
+                                "'omp.alloc_shared_memory' op";
+}
+
 //===----------------------------------------------------------------------===//
 // WorkdistributeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 80e052105dc4c..3accca891ba9c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6104,11 +6104,9 @@ static bool isTargetDeviceOp(Operation *op) {
   // by taking it in as an operand, so we must always lower these in
   // some manner or result in an ICE (whether they end up in a no-op
   // or otherwise).
-  if (mlir::isa<omp::ThreadprivateOp>(op))
-    return true;
-
-  if (mlir::isa<omp::TargetAllocMemOp>(op) ||
-      mlir::isa<omp::TargetFreeMemOp>(op))
+  if (mlir::isa<omp::ThreadprivateOp, omp::TargetAllocMemOp,
+                omp::TargetFreeMemOp, omp::AllocSharedMemOp,
+                omp::FreeSharedMemOp>(op))
     return true;
 
   if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
@@ -6135,6 +6133,21 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
   return func;
 }
 
+static llvm::Value *
+getAllocationSize(llvm::IRBuilderBase &builder,
+                  LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy,
+                  OperandRange typeparams, OperandRange shape) {
+  llvm::DataLayout dataLayout =
+      moduleTranslation.getLLVMModule()->getDataLayout();
+  llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy);
+  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+  for (auto typeParam : typeparams)
+    allocSize =
+        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+  return allocSize;
+}
+
 static LogicalResult
 convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
                         LLVM::ModuleTranslation &moduleTranslation) {
@@ -6149,14 +6162,9 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   mlir::Value deviceNum = allocMemOp.getDevice();
   llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
   // Get the allocation size.
-  llvm::DataLayout dataLayout = llvmModule->getDataLayout();
-  mlir::Type heapTy = allocMemOp.getAllocatedType();
-  llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
-  llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
-  llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
-  for (auto typeParam : allocMemOp.getTypeparams())
-    allocSize =
-        builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+  llvm::Value *allocSize = getAllocationSize(
+      builder, moduleTranslation, allocMemOp.getAllocatedType(),
+      allocMemOp.getTypeparams(), allocMemOp.getShape());
   // Create call to "omp_target_alloc" with the args as translated llvm values.
   llvm::CallInst *call =
       builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
@@ -6167,6 +6175,19 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
+static LogicalResult
+convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp,
+                        llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  llvm::Value *size = getAllocationSize(
+      builder, moduleTranslation, allocMemOp.getAllocatedType(),
+      allocMemOp.getTypeparams(), allocMemOp.getShape());
+  moduleTranslation.mapValue(allocMemOp.getResult(),
+                             ompBuilder->createOMPAllocShared(builder, size));
+  return success();
+}
+
 static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
                                         llvm::Module *llvmModule) {
   llvm::Type *ptrTy = builder.getPtrTy(0);
@@ -6202,6 +6223,21 @@ convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
   return success();
 }
 
+static LogicalResult
+convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp,
+                       llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  auto allocMemOp =
+      freeMemOp.getHeapref().getDefiningOp<omp::AllocSharedMemOp>();
+  llvm::Value *size = getAllocationSize(
+      builder, moduleTranslation, allocMemOp.getAllocatedType(),
+      allocMemOp.getTypeparams(), allocMemOp.getShape());
+  ompBuilder->createOMPFreeShared(
+      builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size);
+  return success();
+}
+
 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
 /// OpenMP runtime calls).
 static LogicalResult
@@ -6382,6 +6418,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
           .Case([&](omp::TargetFreeMemOp) {
             return convertTargetFreeMemOp(*op, builder, moduleTranslation);
           })
+          .Case([&](omp::AllocSharedMemOp op) {
+            return convertAllocSharedMemOp(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::FreeSharedMemOp op) {
+            return convertFreeSharedMemOp(op, builder, moduleTranslation);
+          })
           .Default([&](Operation *inst) {
             return inst->emitError()
                    << "not yet implemented: " << inst->getName();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 0cc4b522db466..9f28172161fa8 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -3153,3 +3153,31 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
   %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2}
   return
 }
+
+// -----
+func.func @target_freemem_invalid_ptr(%device : i32, %ptr : i64) -> () {
+  // expected-error @below {{op 'heapref' operand must be defined by an 'omp.target_allocmem' op}}
+  omp.target_freemem %device, %ptr : i32, i64
+  return
+}
+
+// -----
+func.func @alloc_shared_mem_invalid_uniq_name() -> () {
+  // expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}}
+  %0 = omp.alloc_shared_mem i64 {uniq_name=2}
+  return
+}
+
+// -----
+func.func @alloc_shared_mem_invalid_bindc_name() -> () {
+  // expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}}
+  %0 = omp.alloc_shared_mem i64 {bindc_name=2}
+  return
+}
+
+// -----
+func.func @free_shared_mem_invalid_ptr(%ptr : !llvm.ptr) -> () {
+  // expected-error @below {{op 'heapref' operand must be defined by an 'omp.alloc_shared_memory' op}}
+  omp.free_shared_mem %ptr : !llvm.ptr
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 9e7287178ff66..55e6d77857972 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -3339,9 +3339,36 @@ func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) {
 }
 
 // CHECK-LABEL: func.func @omp_target_freemem(
-// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) {
-func.func @omp_target_freemem(%device : i32, %ptr : i64) {
+// CHECK-SAME: %[[DEVICE:.*]]: i32) {
+func.func @omp_target_freemem(%device : i32) {
+  // CHECK: %[[PTR:.*]] = omp.ta...
[truncated]

Meinersbur (Member) left a comment

Got confused by an unfortunate ambiguity between `llvm_omp_target_alloc_shared` (in the sense of "unified shared memory") and `__kmpc_alloc_shared` (called "contention group" memory in the OpenMP spec).

Nothing to do with this PR though. LGTM.

`omp.alloc_shared_mem` operation. After this operation, the deallocated
memory is in an undefined state and should not be accessed.
It is crucial to ensure that all accesses to the memory region are completed
before `omp.alloc_shared_mem` is called to avoid undefined behavior.
Member

Suggested change
- before `omp.alloc_shared_mem` is called to avoid undefined behavior.
+ before `omp.free_shared_mem` is called to avoid undefined behavior.

The concept of alloc/free might not need such a detailed explanation.

llvm::DataLayout dataLayout =
moduleTranslation.getLLVMModule()->getDataLayout();
llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy);
llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
Contributor

`getTypeAllocSize()` would be better to use here, since it also accounts for alignment padding.
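
LLVM's `DataLayout` defines the alloc size as the store size rounded up to the type's ABI alignment, so the two differ exactly when the store size is not a multiple of the alignment. A small arithmetic sketch, assuming `<3 x i32>` has a 12-byte store size and 16-byte ABI alignment (typical, but data-layout dependent); `roundUp` and `allocSizeFromStore` are local stand-ins, not LLVM's API.

```cpp
#include <cassert>
#include <cstdint>

// Local stand-in: round `value` up to a multiple of `align` (align > 0).
constexpr std::uint64_t roundUp(std::uint64_t value, std::uint64_t align) {
  return (value + align - 1) / align * align;
}

// The relationship DataLayout uses: alloc size = store size rounded up
// to the type's ABI alignment.
constexpr std::uint64_t allocSizeFromStore(std::uint64_t storeSize,
                                           std::uint64_t abiAlign) {
  return roundUp(storeSize, abiAlign);
}

// Assumed figures for <3 x i32>: 12-byte store size, 16-byte ABI alignment.
constexpr std::uint64_t kStoreSize = 12;
constexpr std::uint64_t kAbiAlign = 16;
static_assert(allocSizeFromStore(kStoreSize, kAbiAlign) == 16,
              "store size is padded to the 16-byte alignment");
// Four consecutive elements: 4 * 12 = 48 bytes vs. 4 * 16 = 64 bytes.
static_assert(4 * kStoreSize == 48, "size based on getTypeStoreSize()");
static_assert(4 * allocSizeFromStore(kStoreSize, kAbiAlign) == 64,
              "size based on getTypeAllocSize()");
```

If the `typeparams` multiplied in `getAllocationSize` act as an element count, sizing with the store size would under-allocate by 16 bytes in this example.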


Labels

clang:openmp, flang:codegen, flang:fir-hlfir, flang:openmp, flang, mlir:llvm, mlir:openmp, mlir


4 participants