File tree Expand file tree Collapse file tree 6 files changed +30
-9
lines changed Expand file tree Collapse file tree 6 files changed +30
-9
lines changed Original file line number Diff line number Diff line change @@ -215,7 +215,12 @@ def GenericCallOp : Toy_Op<"generic_call",
215215
216216 // The generic call operation takes a symbol reference attribute as the
217217 // callee, and inputs for the call.
218- let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
218+ let arguments = (ins
219+ FlatSymbolRefAttr:$callee,
220+ Variadic<F64Tensor>:$inputs,
221+ OptionalAttr<DictArrayAttr>:$arg_attrs,
222+ OptionalAttr<DictArrayAttr>:$res_attrs
223+ );
219224
220225 // The generic call operation returns a single value of TensorType.
221226 let results = (outs F64Tensor);
Original file line number Diff line number Diff line change @@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call",
214214
215215 // The generic call operation takes a symbol reference attribute as the
216216 // callee, and inputs for the call.
217- let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
217+ let arguments = (ins
218+ FlatSymbolRefAttr:$callee,
219+ Variadic<F64Tensor>:$inputs,
220+ OptionalAttr<DictArrayAttr>:$arg_attrs,
221+ OptionalAttr<DictArrayAttr>:$res_attrs
222+ );
218223
219224 // The generic call operation returns a single value of TensorType.
220225 let results = (outs F64Tensor);
Original file line number Diff line number Diff line change @@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call",
214214
215215 // The generic call operation takes a symbol reference attribute as the
216216 // callee, and inputs for the call.
217- let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
217+ let arguments = (ins
218+ FlatSymbolRefAttr:$callee,
219+ Variadic<F64Tensor>:$inputs,
220+ OptionalAttr<DictArrayAttr>:$arg_attrs,
221+ OptionalAttr<DictArrayAttr>:$res_attrs
222+ );
218223
219224 // The generic call operation returns a single value of TensorType.
220225 let results = (outs F64Tensor);
Original file line number Diff line number Diff line change @@ -237,7 +237,12 @@ def GenericCallOp : Toy_Op<"generic_call",
237237
238238 // The generic call operation takes a symbol reference attribute as the
239239 // callee, and inputs for the call.
240- let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs);
240+ let arguments = (ins
241+ FlatSymbolRefAttr:$callee,
242+ Variadic<Toy_Type>:$inputs,
243+ OptionalAttr<DictArrayAttr>:$arg_attrs,
244+ OptionalAttr<DictArrayAttr>:$res_attrs
245+ );
241246
242247 // The generic call operation returns a single value of TensorType or
243248 // StructType.
@@ -250,7 +255,8 @@ def GenericCallOp : Toy_Op<"generic_call",
250255
251256 // Add custom build methods for the generic call operation.
252257 let builders = [
253- OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
258+ OpBuilder<(ins "Type":$result_type, "StringRef":$callee,
259+ "ArrayRef<Value>":$arguments)>
254260 ];
255261}
256262
Original file line number Diff line number Diff line change @@ -350,9 +350,10 @@ void FuncOp::print(mlir::OpAsmPrinter &p) {
350350// ===----------------------------------------------------------------------===//
351351
352352void GenericCallOp::build (mlir::OpBuilder &builder, mlir::OperationState &state,
353- StringRef callee, ArrayRef<mlir::Value> arguments) {
353+ mlir::Type resultType, StringRef callee,
354+ ArrayRef<mlir::Value> arguments) {
354355 // Generic call always returns an unranked Tensor initially.
355- state.addTypes (UnrankedTensorType::get (builder. getF64Type ()) );
356+ state.addTypes (resultType );
356357 state.addOperands (arguments);
357358 state.addAttribute (" callee" ,
358359 mlir::SymbolRefAttr::get (builder.getContext (), callee));
Original file line number Diff line number Diff line change @@ -535,8 +535,7 @@ class MLIRGenImpl {
535535 }
536536 mlir::toy::FuncOp calledFunc = calledFuncIt->second ;
537537 return builder.create <GenericCallOp>(
538- location, calledFunc.getFunctionType ().getResult (0 ),
539- mlir::SymbolRefAttr::get (builder.getContext (), callee), operands);
538+ location, calledFunc.getFunctionType ().getResult (0 ), callee, operands);
540539 }
541540
542541 // / Emit a print expression. It emits specific operations for two builtins:
You can’t perform that action at this time.
0 commit comments