Skip to content
5 changes: 2 additions & 3 deletions mlir/include/mlir/Dialect/OpenACC/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,10 @@
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
mlir::acc::DataOp, mlir::acc::DeclareOp
mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \
mlir::acc::DeclareExitOp
mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
#define ACC_DATA_CONSTRUCT_OPS \
ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
collectVars(op.getDataClauseOperands(), values, hostToDevice);
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
!std::is_same_v<Op, acc::DataOp> &&
!std::is_same_v<Op, acc::DeclareOp>) {
!std::is_same_v<Op, acc::DeclareOp> &&
!std::is_same_v<Op, acc::HostDataOp>) {
collectVars(op.getReductionOperands(), values, hostToDevice);
collectVars(op.getPrivateOperands(), values, hostToDevice);
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
Expand Down Expand Up @@ -122,6 +123,8 @@ class LegalizeDataValuesInRegion
collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
} else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
} else {
llvm_unreachable("unsupported acc region op");
}
Expand Down
28 changes: 24 additions & 4 deletions mlir/test/Dialect/OpenACC/legalize-data.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func.func @test(%a: memref<10xf32>) {
return
}

// CHECK: func.func @test
// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
Expand Down Expand Up @@ -140,7 +140,7 @@ func.func @test(%a: memref<10xf32>) {
return
}

// CHECK: func.func @test
// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
Expand Down Expand Up @@ -178,7 +178,7 @@ func.func @test(%a: memref<10xf32>) {
return
}

// CHECK: func.func @test
// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
// CHECK: acc.parallel {
Expand Down Expand Up @@ -216,7 +216,7 @@ func.func @test(%a: memref<10xf32>) {
return
}

// CHECK: func.func @test
// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
Expand All @@ -226,3 +226,23 @@ func.func @test(%a: memref<10xf32>) {
// CHECK: }
// CHECK: acc.yield
// CHECK: }

// -----

func.func @test(%a: memref<10xf32>) {
%devptr = acc.use_device varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
acc.host_data dataOperands(%devptr : memref<10xf32>) {
func.call @foo(%a) : (memref<10xf32>) -> ()
acc.terminator
}
return
}
func.func private @foo(memref<10xf32>)

// CHECK-LABEL: func.func @test
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
// CHECK: %[[USE_DEVICE:.*]] = acc.use_device varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
// CHECK: acc.host_data dataOperands(%[[USE_DEVICE]] : memref<10xf32>) {
// DEVICE: func.call @foo(%[[USE_DEVICE]]) : (memref<10xf32>) -> ()
// CHECK: acc.terminator
// CHECK: }
Loading