Skip to content

Commit 186c911

Browse files
committed
Merge branch 'mlir-call-attrs-iface' into mlir-call-attrs-llvm
2 parents 854e43c + 43cd704 commit 186c911

File tree

24 files changed

+380
-329
lines changed

24 files changed

+380
-329
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
207207
I32:$block_z,
208208
Optional<I32>:$bytes,
209209
Optional<I32>:$stream,
210-
Variadic<AnyType>:$args
210+
Variadic<AnyType>:$args,
211+
OptionalAttr<DictArrayAttr>:$arg_attrs,
212+
OptionalAttr<DictArrayAttr>:$res_attrs
211213
);
212214

213215
let assemblyFormat = [{

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2432,6 +2432,8 @@ def fir_CallOp : fir_Op<"call",
24322432
let arguments = (ins
24332433
OptionalAttr<SymbolRefAttr>:$callee,
24342434
Variadic<AnyType>:$args,
2435+
OptionalAttr<DictArrayAttr>:$arg_attrs,
2436+
OptionalAttr<DictArrayAttr>:$res_attrs,
24352437
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
24362438
DefaultValuedAttr<Arith_FastMathAttr,
24372439
"::mlir::arith::FastMathFlags::none">:$fastmath
@@ -2518,6 +2520,8 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
25182520
fir_ClassType:$object,
25192521
Variadic<AnyType>:$args,
25202522
OptionalAttr<I32Attr>:$pass_arg_pos,
2523+
OptionalAttr<DictArrayAttr>:$arg_attrs,
2524+
OptionalAttr<DictArrayAttr>:$res_attrs,
25212525
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs
25222526
);
25232527

flang/lib/Lower/ConvertCall.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ Fortran::lower::genCallOpAndResult(
594594

595595
builder.create<cuf::KernelLaunchOp>(
596596
loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z,
597-
block_x, block_y, block_z, bytes, stream, operands);
597+
block_x, block_y, block_z, bytes, stream, operands,
598+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
598599
callNumResults = 0;
599600
} else if (caller.requireDispatchCall()) {
600601
// Procedure call requiring a dynamic dispatch. Call is created with
@@ -621,7 +622,8 @@ Fortran::lower::genCallOpAndResult(
621622
dispatch = builder.create<fir::DispatchOp>(
622623
loc, funcType.getResults(), builder.getStringAttr(procName),
623624
caller.getInputs()[*passArg], operands,
624-
builder.getI32IntegerAttr(*passArg), procAttrs);
625+
builder.getI32IntegerAttr(*passArg), /*arg_attrs=*/nullptr,
626+
/*res_attrs=*/nullptr, procAttrs);
625627
} else {
626628
// NOPASS
627629
const Fortran::evaluate::Component *component =
@@ -636,15 +638,17 @@ Fortran::lower::genCallOpAndResult(
636638
passObject = builder.create<fir::LoadOp>(loc, passObject);
637639
dispatch = builder.create<fir::DispatchOp>(
638640
loc, funcType.getResults(), builder.getStringAttr(procName),
639-
passObject, operands, nullptr, procAttrs);
641+
passObject, operands, nullptr, /*arg_attrs=*/nullptr,
642+
/*res_attrs=*/nullptr, procAttrs);
640643
}
641644
callNumResults = dispatch.getNumResults();
642645
if (callNumResults != 0)
643646
callResult = dispatch.getResult(0);
644647
} else {
645648
// Standard procedure call with fir.call.
646649
auto call = builder.create<fir::CallOp>(
647-
loc, funcType.getResults(), funcSymbolAttr, operands, procAttrs);
650+
loc, funcType.getResults(), funcSymbolAttr, operands,
651+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs);
648652

649653
callNumResults = call.getNumResults();
650654
if (callNumResults != 0)

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
518518
newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
519519

520520
llvm::SmallVector<mlir::Value, 1> newCallResults;
521+
// TODO propagate/update call argument and result attributes.
521522
if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
522523
auto newCall = rewriter->create<A>(
523524
loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
@@ -557,6 +558,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
557558
loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
558559
callOp.getOperands()[0], newOpers,
559560
rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift),
561+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
560562
callOp.getProcedureAttrsAttr());
561563
if (wrap)
562564
newCallResults.push_back((*wrap)(dispatchOp.getOperation()));

flang/lib/Optimizer/Transforms/AbstractResult.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
147147
newResultTypes.emplace_back(getVoidPtrType(result.getContext()));
148148

149149
Op newOp;
150+
// TODO: propagate argument and result attributes (need to be shifted).
150151
// fir::CallOp specific handling.
151152
if constexpr (std::is_same_v<Op, fir::CallOp>) {
152153
if (op.getCallee()) {
@@ -189,9 +190,11 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
189190
if (op.getPassArgPos())
190191
passArgPos =
191192
rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
193+
// TODO: propagate argument and result attributes (need to be shifted).
192194
newOp = rewriter.create<fir::DispatchOp>(
193195
loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
194196
op.getOperands()[0], newOperands, passArgPos,
197+
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
195198
op.getProcedureAttrsAttr());
196199
}
197200

flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,9 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
205205
// Make the call.
206206
llvm::SmallVector<mlir::Value> args{funcPtr};
207207
args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
208-
rewriter.replaceOpWithNewOp<fir::CallOp>(dispatch, resTypes, nullptr, args,
209-
dispatch.getProcedureAttrsAttr());
208+
rewriter.replaceOpWithNewOp<fir::CallOp>(
209+
dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(),
210+
dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr());
210211
return mlir::success();
211212
}
212213

mlir/docs/Interfaces.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -753,22 +753,25 @@ interface section goes as follows:
753753
- (`C++ class` -- `ODS class`(if applicable))
754754

755755
##### CallInterfaces
756-
* `OpWithArgumentAttributesInterface` - Used to represent operations that may
757-
carry argument and result attributes. It is inherited by both
758-
CallOpInterface and CallableOpInterface.
756+
* `CallOpInterface` - Used to represent operations like 'call'
757+
- `CallInterfaceCallable getCallableForCallee()`
758+
- `void setCalleeFromCallable(CallInterfaceCallable)`
759759
- `ArrayAttr getArgAttrsAttr()`
760760
- `ArrayAttr getResAttrsAttr()`
761761
- `void setArgAttrsAttr(ArrayAttr)`
762762
- `void setResAttrsAttr(ArrayAttr)`
763763
- `Attribute removeArgAttrsAttr()`
764764
- `Attribute removeResAttrsAttr()`
765-
* `CallOpInterface` - Used to represent operations like 'call'
766-
- `CallInterfaceCallable getCallableForCallee()`
767-
- `void setCalleeFromCallable(CallInterfaceCallable)`
768765
* `CallableOpInterface` - Used to represent the target callee of call.
769766
- `Region * getCallableRegion()`
770767
- `ArrayRef<Type> getArgumentTypes()`
771768
- `ArrayRef<Type> getResultsTypes()`
769+
- `ArrayAttr getArgAttrsAttr()`
770+
- `ArrayAttr getResAttrsAttr()`
771+
- `void setArgAttrsAttr(ArrayAttr)`
772+
- `void setResAttrsAttr(ArrayAttr)`
773+
- `Attribute removeArgAttrsAttr()`
774+
- `Attribute removeResAttrsAttr()`
772775

773776
##### RegionKindInterfaces
774777

mlir/include/mlir/Dialect/Async/IR/AsyncOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,13 @@ def Async_CallOp : Async_Op<"call",
208208
```
209209
}];
210210

211-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
211+
let arguments = (ins
212+
FlatSymbolRefAttr:$callee,
213+
Variadic<AnyType>:$operands,
214+
OptionalAttr<DictArrayAttr>:$arg_attrs,
215+
OptionalAttr<DictArrayAttr>:$res_attrs
216+
);
217+
212218
let results = (outs Variadic<Async_AnyValueOrTokenType>);
213219

214220
let builders = [

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,13 @@ def EmitC_CallOp : EmitC_Op<"call",
533533
%2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
534534
```
535535
}];
536-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<EmitCType>:$operands);
536+
let arguments = (ins
537+
FlatSymbolRefAttr:$callee,
538+
Variadic<EmitCType>:$operands,
539+
OptionalAttr<DictArrayAttr>:$arg_attrs,
540+
OptionalAttr<DictArrayAttr>:$res_attrs
541+
);
542+
537543
let results = (outs Variadic<EmitCType>);
538544

539545
let builders = [

mlir/include/mlir/Dialect/Func/IR/FuncOps.td

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,14 @@ def CallOp : Func_Op<"call",
4949
```
5050
}];
5151

52-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands,
53-
UnitAttr:$no_inline);
52+
let arguments = (ins
53+
FlatSymbolRefAttr:$callee,
54+
Variadic<AnyType>:$operands,
55+
OptionalAttr<DictArrayAttr>:$arg_attrs,
56+
OptionalAttr<DictArrayAttr>:$res_attrs,
57+
UnitAttr:$no_inline
58+
);
59+
5460
let results = (outs Variadic<AnyType>);
5561

5662
let builders = [
@@ -73,6 +79,18 @@ def CallOp : Func_Op<"call",
7379
CArg<"ValueRange", "{}">:$operands), [{
7480
build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
7581
results, operands);
82+
}]>,
83+
OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
84+
CArg<"ValueRange", "{}">:$operands), [{
85+
build($_builder, $_state, callee, results, operands);
86+
}]>,
87+
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
88+
CArg<"ValueRange", "{}">:$operands), [{
89+
build($_builder, $_state, callee, results, operands);
90+
}]>,
91+
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
92+
CArg<"ValueRange", "{}">:$operands), [{
93+
build($_builder, $_state, callee, results, operands);
7694
}]>];
7795

7896
let extraClassDeclaration = [{
@@ -136,8 +154,13 @@ def CallIndirectOp : Func_Op<"call_indirect", [
136154
```
137155
}];
138156

139-
let arguments = (ins FunctionType:$callee,
140-
Variadic<AnyType>:$callee_operands);
157+
let arguments = (ins
158+
FunctionType:$callee,
159+
Variadic<AnyType>:$callee_operands,
160+
OptionalAttr<DictArrayAttr>:$arg_attrs,
161+
OptionalAttr<DictArrayAttr>:$res_attrs
162+
);
163+
141164
let results = (outs Variadic<AnyType>:$results);
142165

143166
let builders = [

0 commit comments

Comments
 (0)