Skip to content

Commit 8051236

Browse files
authored
[onert] Load BCQUnembedding operator (#16216)
* [onert] Load BCQUnembedding operator This commit updates CircleLoader to load `BCQUnembedding` custom operator. Signed-off-by: Seok Namkoong <seok9311@naver.com> * fix error handling --------- Signed-off-by: Seok Namkoong <seok9311@naver.com>
1 parent 8a6e1f3 commit 8051236

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

runtime/onert/core/src/loader/BaseLoader.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ template <typename LoaderDomain> class BaseLoader
101101
void loadSparsity(const Tensor *tensor, ir::TypeInfo &typeInfo);
102102
void loadOperationIO(const Operator *op, ir::OperandIndexSequence &inputs,
103103
ir::OperandIndexSequence &outputs);
104+
template <typename OpIR, typename... Args>
105+
const OpIR *loadOperationTo(const Operator *op, ir::Graph &subg, Args &&...args);
104106
// Create operations from Operator
105107
void loadOperation(const Operator *op, ir::Graph &subg);
106108
// Load Strides and Paddings from options to param
@@ -131,8 +133,6 @@ template <typename LoaderDomain> class BaseLoader
131133
std::unique_ptr<ir::Data> loadMetadata(const uint32_t buffer_idx);
132134
virtual std::unique_ptr<ir::Graph> loadSubgraph(const SubGraph *subg) = 0;
133135
// Operations
134-
template <typename OpIR, typename... Args>
135-
const OpIR *loadOperationTo(const Operator *op, ir::Graph &subg, Args &&...args);
136136

137137
void loadArgMinMax(const Operator *op, ir::Graph &subg, bool is_argmax);
138138
void loadBinaryArithmetic(const Operator *op, ir::Graph &subg,

runtime/onert/core/src/loader/CircleLoader.cc

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class CircleLoader final : public loader::BaseLoader<LoaderDomain>
7575
void loadRoPE(const Operator *op, ir::Graph &subg);
7676
void loadCall(const Operator *op, ir::Graph &subg);
7777
void loadRunModel(const Operator *op, ir::Graph &subg);
78+
void loadBCQUnembedding(const Operator *op, ir::Graph &subg);
7879
void loadCustom(const Operator *op, ir::Graph &subg);
7980

8081
public:
@@ -86,6 +87,7 @@ class CircleLoader final : public loader::BaseLoader<LoaderDomain>
8687
{
8788
case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
8889
case BuiltinOperator::BuiltinOperator_BCQ_FULLY_CONNECTED:
90+
case BuiltinOperator::BuiltinOperator_CUSTOM:
8991
case BuiltinOperator::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
9092
return true;
9193
default:
@@ -182,6 +184,9 @@ class CircleLoader final : public loader::BaseLoader<LoaderDomain>
182184
case circle::BuiltinOperator::BuiltinOperator_RUN_MODEL:
183185
loadRunModel(op, subg);
184186
return;
187+
case circle::BuiltinOperator::BuiltinOperator_CUSTOM:
188+
loadCustom(op, subg);
189+
return;
185190
default:
186191
BaseLoader::loadOperation(op, subg);
187192
return;
@@ -350,6 +355,76 @@ void CircleLoader::loadRunModel(const Operator *op, ir::Graph &subg)
350355
subg.addOperation(std::move(new_op));
351356
}
352357

358+
void CircleLoader::loadBCQUnembedding(const Operator *op, ir::Graph &subg)
359+
{
360+
ir::OperandIndexSequence inputs;
361+
ir::OperandIndexSequence outputs;
362+
363+
loadOperationIO(op, inputs, outputs);
364+
365+
ir::operation::BCQUnembedding::Param param;
366+
if (op->custom_options() == nullptr)
367+
{
368+
throw std::runtime_error{"BCQUnembedding: empty option"};
369+
}
370+
else
371+
{
372+
const auto attr_map = getCustomOpAttrMap(op);
373+
param.weights_hidden_size = attr_map["weights_hidden_size"].AsUInt32();
374+
param.lsh_type = attr_map["lsh_type"].AsString().str();
375+
param.lsh_choices = attr_map["lsh_choices"].AsInt32();
376+
}
377+
378+
const auto fbn = loadOperationTo<ir::operation::BCQUnembedding>(op, subg, param);
379+
380+
if (fbn->getInputs().size() != 5)
381+
{
382+
throw std::runtime_error{"BCQUnembedding: NYI input - only support five inputs"};
383+
}
384+
}
385+
386+
void CircleLoader::loadCustom(const Operator *op, ir::Graph &subg)
387+
{
388+
ir::OperandIndexSequence inputs;
389+
ir::OperandIndexSequence outputs;
390+
391+
assert(op->custom_options_format() == CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS &&
392+
"Unsupported custom operation options format");
393+
394+
auto *op_code = _domain_model->operator_codes()->Get(op->opcode_index());
395+
auto custom_op_name = op_code->custom_code()->str();
396+
397+
enum class BuiltinOP
398+
{
399+
BCQUnembedding,
400+
};
401+
402+
// Mapping from custom op name string to BuiltinOP enum
403+
std::map<std::string, BuiltinOP> builtin_map = {
404+
{"BCQUnembedding", BuiltinOP::BCQUnembedding},
405+
};
406+
407+
// If unknown circle custom op, pass to BaseLoader
408+
if (builtin_map.find(custom_op_name) == builtin_map.end())
409+
{
410+
BaseLoader::loadOperation(op, subg);
411+
return;
412+
}
413+
414+
auto custom_op_id = builtin_map.at(custom_op_name);
415+
switch (custom_op_id)
416+
{
417+
case BuiltinOP::BCQUnembedding:
418+
loadBCQUnembedding(op, subg);
419+
break;
420+
default:
421+
throw std::runtime_error{"CircleLoader: Circle Custom OP map is defined but operation loader "
422+
"function is not defined"};
423+
}
424+
425+
return;
426+
}
427+
353428
} // namespace
354429

355430
std::unique_ptr<ir::Model> loadCircleModel(const std::string &filename)

0 commit comments

Comments
 (0)