Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class CIR_VariadicTensorOf<list<Type> allowedRuntimeTypes,
Variadic<TensorOf<allowedOpTypes>>,
CIR_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;

def CIR_I4 : I<4>;
def CIR_Int32Or64 : SignlessIntOfWidths<[32, 64]>;

def CIR_BoolTensor : CIR_TensorOf<[I1]>;
Expand Down Expand Up @@ -259,6 +260,8 @@ class CIR_Op<string mnemonic, list<Trait> traits = []> :

// Whether the Circle operator has options in the schema representation.
bit hasOptions = 0b0;
// Whether the Circle operator has options2 in the schema representation.
bit hasOptions2 = 0b0;

// Use to specify a custom options type for Circle operators where
// the option's name does not match the Cirlce operator's name.
Expand Down
85 changes: 72 additions & 13 deletions circle-mlir/circle-mlir/lib/tools/converter-gen/converter_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,28 @@ static inline bool IsLstmOp(const StringRef op_name) {
return op_name.take_back(6) == "LSTMOp";
}

static int HasOptions(const Record &def) {
if (def.getValueAsBit("hasOptions")) {
return 1;
}
if (def.getValueAsBit("hasOptions2")) {
return 2;
}
return 0;
}

static void EmitOptionBuilders(const RecordKeeper &record_keeper,
const std::vector<const Record *> &defs,
raw_ostream *ostream) {
raw_ostream &os = *ostream;

const auto attr_type = record_keeper.getClass("Attr");
for (const auto *def : defs) {
const int has_options = HasOptions(*def);
// Circle ops without options are skipped over.
if (!def->getValueAsBit("hasOptions")) continue;
if (!has_options) {
continue;
}

StringRef op_name = def->getName().drop_front(4); // Strip 'CIR_' prefix
std::string option_name = GetOperatorOptionName(*def);
Expand Down Expand Up @@ -204,7 +217,8 @@ static void EmitOperatorBuilders(const std::vector<const Record *> &defs,
// Build the FlatBuffer operator
os << " return circle::CreateOperator(\n"
" *fbb, opcode_index, inputs, outputs,\n";
if (def->getValueAsBit("hasOptions")) {
const int has_options = HasOptions(*def);
if (has_options == 1) {
auto option_name = GetOperatorOptionName(*def);
std::string circle_option_name =
option_name == "BasicLSTMOptions" ? "LSTMOptions" : option_name;
Expand All @@ -217,8 +231,26 @@ static void EmitOperatorBuilders(const std::vector<const Record *> &defs,
// used by custom or flex ops and those ops are handled manually.
os << " /*custom_options=*/0, "
<< "circle::CustomOptionsFormat_FLEXBUFFERS,\n"
<< " /*mutating_variable_inputs=*/0"
<< (has_intermediates ? ", intermediates" : "") << ");\n}\n\n";
<< " /*mutating_variable_inputs=*/0,"
<< (has_intermediates ? "intermediates" : "/*intermediates=*/0");

if (has_options == 2) {
os << ",\n"
<< " /*large_custom_options_offset=*/0,\n"
<< " /*large_custom_options_size=*/0";
os << ",\n";
const std::string option_name = GetOperatorOptionName(*def);
os << " circle::BuiltinOptions2_" << option_name << ", "
<< "Create" << option_name << "(tflOp, fbb).Union()";
} else {
os << ",\n"
<< " /*large_custom_options_offset=*/0,\n"
<< " /*large_custom_options_size=*/0";
os << ",\n";
os << " circle::BuiltinOptions2_NONE, /*builtin_options2=*/0";
}

os << ");\n}\n\n";
}
}

Expand Down Expand Up @@ -355,24 +387,43 @@ static void EmitBuildOperator(const std::vector<const Record *> &defs,

// Emit a function that converts a BuiltinOptionsUnion to a vector of attributes
// Signature:
// void mlir::BuiltinOptionsToAttributes(
// circle::BuiltinOptionsUnion op_union,
// void mlir::BuiltinOptions{id}ToAttributes(
// circle::BuiltinOptions{id}Union op_union,
// mlir::Builder builder,
// llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes);
static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper,
const std::vector<const Record *> &defs,
raw_ostream *ostream) {
//
// where id is an empty string if builtin_options_id is 1, or builtin_options_id
// otherwise.
static void EmitBuiltinOptionsToAttributes(
const RecordKeeper &record_keeper, const std::vector<const Record *> &defs,
raw_ostream *ostream, const int builtin_options_id) {
raw_ostream &os = *ostream;

const std::string builtin_options_suffix = [&] {
switch (builtin_options_id) {
case 1:
return "";
case 2:
return "2";
}
return "UnknownId";
}();

// Signature
os << "void mlir::BuiltinOptionsToAttributes("
"circle::BuiltinOptionsUnion op_union, "
os << "void mlir::BuiltinOptions" << builtin_options_suffix
<< "ToAttributes("
"circle::BuiltinOptions"
<< builtin_options_suffix
<< "Union op_union, "
"mlir::Builder builder, "
"llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes) {\n";

const auto attr_type = record_keeper.getClass("Attr");
for (const auto *def : defs) {
if (!def->getValueAsBit("hasOptions")) continue;
const int has_options = HasOptions(*def);
if (has_options != builtin_options_id) {
continue;
}
auto option_name = GetOperatorOptionName(*def);
// Basic LSTM and LSTM ops share the same option to attribute converter.
if (option_name == "BasicLSTMOptions") {
Expand Down Expand Up @@ -405,9 +456,14 @@ static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper,
os << " return;\n";
os << " }\n";
}
if (builtin_options_id == 2) {
os << " BuiltinOptions2ToAttributesManual(op_union, builder, "
"attributes);\n";
}
// Fallthrough case is no attributes
os << "}";
}

// The function below has a non-constant reference as that is required by LLVM's
// TableGenMain.
// NOLINTNEXTLINE
Expand Down Expand Up @@ -440,8 +496,11 @@ static bool OperatorWritersMain(raw_ostream &os, const RecordKeeper &records) {
os << "\n\n";
EmitBuildOperator(defs, &os);
os << "\n\n";
EmitBuiltinOptionsToAttributes(records, defs, &os);
EmitBuiltinOptionsToAttributes(records, defs, &os, /*builtin_options_id=*/1);
os << "\n\n";
// TODO support options2
//EmitBuiltinOptionsToAttributes(records, defs, &os, /*builtin_options_id=*/2);
//os << "\n\n";
EmitOperandNumbers(records, defs, &os);

return false;
Expand Down