Skip to content

Commit 9bab7ec

Browse files
committed
[flang] Fix parsing and printing.
1 parent 1656f33 commit 9bab7ec

File tree

4 files changed

+145
-14
lines changed

4 files changed

+145
-14
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,8 @@ def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem",
528528
be paired with `omp_target_freemem` operations to avoid memory leaks.
529529

530530
```
531-
%0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap<!fir.array<?xf32>>
531+
%device = arith.constant 0 : i32
532+
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
532533
```
533534
}];
534535

@@ -542,6 +543,9 @@ def fir_OmpTargetAllocMemOp : fir_Op<"omp_target_allocmem",
542543
);
543544
let results = (outs fir_HeapType);
544545

546+
let hasCustomAssemblyFormat = 1;
547+
let hasVerifier = 1;
548+
545549
let extraClassDeclaration = [{
546550
mlir::Type getAllocatedType();
547551
bool hasLenParams() { return !getTypeparams().empty(); }
@@ -563,16 +567,17 @@ def fir_OmpTargetFreeMemOp : fir_Op<"omp_target_freemem",
563567
The memory object that is deallocated is placed in an undefined state
564568
after `fir.omp_target_freemem`.
565569
```
566-
%0 = "fir.omp_target_allocmem"(%device, %type) : (i32, index) -> !fir.heap<!fir.array<?xf32>>
567-
...
568-
"fir.omp_target_freemem"(%device, %0) : (i32, !fir.heap<!fir.array<?xf32>>) -> ()
570+
%device = arith.constant 0 : i32
571+
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
572+
fir.omp_target_freemem %device, %1 : i32, !fir.heap<!fir.array<?xf32>>
569573
```
570574
}];
571575

572576
let arguments = (ins
573577
Arg<AnyIntegerType, "", [MemFree]>:$device,
574578
Arg<fir_HeapType, "", [MemFree]>:$heapref
575579
);
580+
let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))";
576581
}
577582

578583
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,24 +106,38 @@ static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) {
106106
return false;
107107
}
108108

109-
/// Parser shared by Alloca and Allocmem
110-
///
109+
/// Parser shared by Alloca, Allocmem and OmpTargetAllocmem
110+
/// boolean flag isTargetOp is used to identify omp_target_allocmem
111111
/// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type
112112
/// ( `(` $typeparams `)` )? ( `,` $shape )?
113113
/// attr-dict-without-keyword
114+
/// operation ::= %res = (`fir.omp_target_alloca`) $device : devicetype,
115+
/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
116+
/// attr-dict-without-keyword
114117
template <typename FN>
115-
static mlir::ParseResult parseAllocatableOp(FN wrapResultType,
116-
mlir::OpAsmParser &parser,
117-
mlir::OperationState &result) {
118+
static mlir::ParseResult
119+
parseAllocatableOp(FN wrapResultType, mlir::OpAsmParser &parser,
120+
mlir::OperationState &result, bool isTargetOp = false) {
121+
auto &builder = parser.getBuilder();
122+
bool hasOperands = false;
123+
std::int32_t typeparamsSize = 0;
124+
// Parse device number as a new operand
125+
if (isTargetOp) {
126+
mlir::OpAsmParser::UnresolvedOperand deviceOperand;
127+
mlir::Type deviceType;
128+
if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
129+
return mlir::failure();
130+
if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
131+
return mlir::failure();
132+
if (parser.parseComma())
133+
return mlir::failure();
134+
}
118135
mlir::Type intype;
119136
if (parser.parseType(intype))
120137
return mlir::failure();
121-
auto &builder = parser.getBuilder();
122138
result.addAttribute("in_type", mlir::TypeAttr::get(intype));
123139
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
124140
llvm::SmallVector<mlir::Type> typeVec;
125-
bool hasOperands = false;
126-
std::int32_t typeparamsSize = 0;
127141
if (!parser.parseOptionalLParen()) {
128142
// parse the LEN params of the derived type. (<params> : <types>)
129143
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
@@ -147,13 +161,19 @@ static mlir::ParseResult parseAllocatableOp(FN wrapResultType,
147161
parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
148162
result.operands))
149163
return mlir::failure();
164+
150165
mlir::Type restype = wrapResultType(intype);
151166
if (!restype) {
152167
parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
153168
return mlir::failure();
154169
}
155-
result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr(
156-
{typeparamsSize, shapeSize}));
170+
llvm::SmallVector<std::int32_t> segmentSizes;
171+
if (isTargetOp)
172+
segmentSizes.push_back(1);
173+
segmentSizes.push_back(typeparamsSize);
174+
segmentSizes.push_back(shapeSize);
175+
result.addAttribute("operandSegmentSizes",
176+
builder.getDenseI32ArrayAttr(segmentSizes));
157177
if (parser.parseOptionalAttrDict(result.attributes) ||
158178
parser.addTypeToList(restype, result.types))
159179
return mlir::failure();
@@ -385,6 +405,56 @@ llvm::LogicalResult fir::AllocMemOp::verify() {
385405
return mlir::success();
386406
}
387407

408+
//===----------------------------------------------------------------------===//
409+
// OmpTargetAllocMemOp
410+
//===----------------------------------------------------------------------===//
411+
412+
mlir::Type fir::OmpTargetAllocMemOp::getAllocatedType() {
413+
return mlir::cast<fir::HeapType>(getType()).getEleTy();
414+
}
415+
416+
mlir::Type fir::OmpTargetAllocMemOp::getRefTy(mlir::Type ty) {
417+
return fir::HeapType::get(ty);
418+
}
419+
420+
mlir::ParseResult
421+
fir::OmpTargetAllocMemOp::parse(mlir::OpAsmParser &parser,
422+
mlir::OperationState &result) {
423+
return parseAllocatableOp(wrapAllocMemResultType, parser, result, true);
424+
}
425+
426+
void fir::OmpTargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
427+
p << " ";
428+
p.printOperand(getDevice());
429+
p << " : ";
430+
p << getDevice().getType();
431+
p << ", ";
432+
p << getInType();
433+
if (!getTypeparams().empty()) {
434+
p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
435+
}
436+
for (auto sh : getShape()) {
437+
p << ", ";
438+
p.printOperand(sh);
439+
}
440+
p.printOptionalAttrDict((*this)->getAttrs(),
441+
{"in_type", "operandSegmentSizes"});
442+
}
443+
444+
llvm::LogicalResult fir::OmpTargetAllocMemOp::verify() {
445+
llvm::SmallVector<llvm::StringRef> visited;
446+
if (verifyInType(getInType(), visited, numShapeOperands()))
447+
return emitOpError("invalid type for allocation");
448+
if (verifyTypeParamCount(getInType(), numLenParams()))
449+
return emitOpError("LEN params do not correspond to type");
450+
mlir::Type outType = getType();
451+
if (!mlir::dyn_cast<fir::HeapType>(outType))
452+
return emitOpError("must be a !fir.heap type");
453+
if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType)))
454+
return emitOpError("cannot allocate !fir.box of unknown rank or type");
455+
return mlir::success();
456+
}
457+
388458
//===----------------------------------------------------------------------===//
389459
// ArrayCoorOp
390460
//===----------------------------------------------------------------------===//
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s
2+
3+
// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_nonchar(
4+
// CHECK: call ptr @omp_target_alloc(i64 36, i32 0)
5+
func.func @omp_target_allocmem_array_of_nonchar() -> !fir.heap<!fir.array<3x3xi32>> {
6+
%device = arith.constant 0 : i32
7+
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
8+
return %1 : !fir.heap<!fir.array<3x3xi32>>
9+
}
10+
11+
// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_char(
12+
// CHECK: call ptr @omp_target_alloc(i64 90, i32 0)
13+
func.func @omp_target_allocmem_array_of_char() -> !fir.heap<!fir.array<3x3x!fir.char<1,10>>> {
14+
%device = arith.constant 0 : i32
15+
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>>
16+
return %1 : !fir.heap<!fir.array<3x3x!fir.char<1,10>>>
17+
}
18+
19+
// CHECK-LABEL: define ptr @omp_target_allocmem_array_of_dynchar(
20+
// CHECK-SAME: i32 %[[len:.*]])
21+
// CHECK: %[[mul1:.*]] = sext i32 %[[len]] to i64
22+
// CHECK: %[[mul2:.*]] = mul i64 9, %[[mul1]]
23+
// CHECK: call ptr @omp_target_alloc(i64 %[[mul2]], i32 0)
24+
func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> !fir.heap<!fir.array<3x3x!fir.char<1,?>>> {
25+
%device = arith.constant 0 : i32
26+
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32)
27+
return %1 : !fir.heap<!fir.array<3x3x!fir.char<1,?>>>
28+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s
2+
3+
// CHECK-LABEL: define void @omp_target_allocmem_array_of_nonchar(
4+
// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0)
5+
func.func @omp_target_allocmem_array_of_nonchar() -> () {
6+
%device = arith.constant 0 : i32
7+
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3xi32>
8+
fir.omp_target_freemem %device, %1 : i32, !fir.heap<!fir.array<3x3xi32>>
9+
return
10+
}
11+
12+
// CHECK-LABEL: define void @omp_target_allocmem_array_of_char(
13+
// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0)
14+
func.func @omp_target_allocmem_array_of_char() -> () {
15+
%device = arith.constant 0 : i32
16+
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,10>>
17+
fir.omp_target_freemem %device, %1 : i32, !fir.heap<!fir.array<3x3x!fir.char<1,10>>>
18+
return
19+
}
20+
21+
// CHECK-LABEL: define void @omp_target_allocmem_array_of_dynchar(
22+
// CHECK: call void @omp_target_free(ptr {{.*}}, i32 0)
23+
func.func @omp_target_allocmem_array_of_dynchar(%l: i32) -> () {
24+
%device = arith.constant 0 : i32
25+
%1 = fir.omp_target_allocmem %device : i32, !fir.array<3x3x!fir.char<1,?>>(%l : i32)
26+
fir.omp_target_freemem %device, %1 : i32, !fir.heap<!fir.array<3x3x!fir.char<1,?>>>
27+
return
28+
}

0 commit comments

Comments
 (0)