Skip to content

Commit a5315f0

Browse files
authored
[onert] Support multiple Trix models in Bulk operation (#16328)
This commit revises TrixLoader and Bulk op to support multiple models. ONE-DCO-1.0-Signed-off-by: Jonghwa Lee <jonghwa3.lee@samsung.com>
1 parent e427b7a commit a5315f0

File tree

5 files changed

+79
-87
lines changed

5 files changed

+79
-87
lines changed

runtime/onert/backend/trix/KernelGenerator.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,18 @@ void KernelGenerator::visit(const ir::operation::Bulk &node)
6666
// parameters
6767
const auto &binary_path = node.param().binary_path;
6868

69-
auto fn = std::make_unique<ops::BulkLayer>();
70-
71-
fn->configure(input_tensors, output_tensors, binary_path, _dev_context);
72-
73-
_return_fn = std::move(fn);
69+
if (binary_path.size() == 1)
70+
{
71+
// For single model execution
72+
auto fn = std::make_unique<ops::BulkLayer>();
73+
fn->configure(input_tensors, output_tensors, binary_path.front(), _dev_context);
74+
_return_fn = std::move(fn);
75+
}
76+
else
77+
{
78+
// TODO: Implement multiple model execution
79+
throw std::runtime_error("NYI: multiple model execution");
80+
}
7481
}
7582

7683
} // namespace onert::backend::trix

runtime/onert/core/include/ir/operation/Bulk.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class Bulk : public Operation
2727
public:
2828
struct Param
2929
{
30-
std::string binary_path;
30+
std::vector<std::string> binary_path;
3131
std::vector<ir::Shape> origin_input_shapes;
3232
std::vector<ir::Shape> origin_output_shapes;
3333
};

runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ operation::BroadcastTo generateBroadcastTo()
8585
operation::Bulk generateBulk()
8686
{
8787
operation::Bulk::Param param;
88-
param.binary_path = "";
88+
param.binary_path = {};
8989
param.origin_input_shapes = std::vector<onert::ir::Shape>{};
9090
param.origin_output_shapes = std::vector<onert::ir::Shape>{};
9191

runtime/onert/core/src/loader/CircleLoader.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,9 @@ void CircleLoader::loadRunModel(const Operator *op, ir::Graph &subg)
337337
auto model_base_path = std::filesystem::path(_file_path).parent_path();
338338
auto *options = op->builtin_options_as_RunModelOptions();
339339
auto location = options->location()->str();
340-
auto model_path = model_base_path / location;
340+
// Multiple files can be specified as ';' separated string
341+
auto model_path = (location.find(';') == std::string::npos) ? model_base_path / location
342+
: std::filesystem::path(location);
341343
auto extension_path = model_path.extension();
342344
if (extension_path.empty())
343345
{

runtime/onert/loader/trix/TrixLoader.cc

Lines changed: 62 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,16 @@ class TrixLoader : public onert::loader::ILoader
9393
*/
9494
void loadModel(std::unique_ptr<ir::Model> &model);
9595
std::unique_ptr<ir::Graph> loadSubgraph();
96-
void loadOperands(ir::Graph &subg);
97-
ir::OperandIndex loadOperandFromInput(uint32_t i, ir::Graph &subg);
98-
ir::OperandIndex loadOperandFromOutput(uint32_t i, ir::Graph &subg);
9996
void loadBulk(ir::Graph &subg);
100-
void loadOperationIO(ir::OperandIndexSequence &inputs, ir::OperandIndexSequence &outputs);
10197
ir::OperandIndex inputIdxToOperandIdx(uint32_t i) const;
10298
ir::OperandIndex outputIdxToOperandIdx(uint32_t i) const;
10399
ir::DataType toDataType(const data_type type) const;
100+
void loadInputOperands(TrixMetaReader &meta, ir::Graph &subg);
101+
void loadOutputOperands(TrixMetaReader &meta, ir::Graph &subg);
104102

105103
private:
106104
/** path to model (e.g. tvn) */
107-
std::string _model_path;
105+
std::vector<std::string> _model_path;
108106
/** original IO shapes */
109107
std::vector<ir::Shape> _origin_input_shapes;
110108
std::vector<ir::Shape> _origin_output_shapes;
@@ -130,102 +128,80 @@ ir::OperandIndex TrixLoader::outputIdxToOperandIdx(uint32_t i) const
130128
return ir::OperandIndex(_meta.input_seg_num() + i);
131129
}
132130

133-
void TrixLoader::loadOperationIO(ir::OperandIndexSequence &inputs,
134-
ir::OperandIndexSequence &outputs)
135-
{
136-
for (uint32_t i = 0; i < _meta.input_seg_num(); ++i)
137-
{
138-
inputs.append(inputIdxToOperandIdx(i));
139-
}
140-
141-
for (uint32_t i = 0; i < _meta.output_seg_num(); ++i)
142-
{
143-
outputs.append(outputIdxToOperandIdx(i));
144-
}
145-
}
146-
147131
void TrixLoader::loadBulk(ir::Graph &subg)
148132
{
149133
ir::operation::Bulk::Param param;
150134
param.binary_path = _model_path;
151135
param.origin_input_shapes = _origin_input_shapes;
152136
param.origin_output_shapes = _origin_output_shapes;
153137

154-
ir::OperandIndexSequence inputs;
155-
ir::OperandIndexSequence outputs;
156-
157-
loadOperationIO(inputs, outputs);
158-
159-
std::unique_ptr<ir::operation::Bulk> bulk(new ir::operation::Bulk(inputs, outputs, param));
138+
std::unique_ptr<ir::operation::Bulk> bulk(
139+
new ir::operation::Bulk(subg.getInputs(), subg.getOutputs(), param));
160140
subg.addOperation(std::move(bulk));
161141
}
162142

163-
ir::OperandIndex TrixLoader::loadOperandFromInput(uint32_t idx, ir::Graph &subg)
164-
{
165-
// Shape
166-
ir::Shape shape;
167-
for (uint32_t d = 0; d < MAX_RANK; ++d)
168-
shape.append(_meta.input_seg_dims(idx, d));
169-
170-
// TypeInfo
171-
ir::TypeInfo type_info(toDataType(_meta.input_seg_quant_type(idx)),
172-
_meta.input_seg_quant_scale(idx), _meta.input_seg_quant_zp(idx));
173-
174-
_origin_input_shapes.push_back(shape);
175-
// Create operand
176-
const auto operand_index = subg.addOperand(shape, type_info);
177-
return operand_index;
178-
}
179-
180-
ir::OperandIndex TrixLoader::loadOperandFromOutput(uint32_t idx, ir::Graph &subg)
143+
void TrixLoader::loadInputOperands(TrixMetaReader &meta, ir::Graph &subg)
181144
{
182-
// Shape
183-
ir::Shape shape;
184-
for (uint32_t d = 0; d < MAX_RANK; ++d)
185-
shape.append(_meta.output_seg_dims(idx, d));
186-
187-
// TypeInfo
188-
ir::TypeInfo type_info(toDataType(_meta.output_seg_quant_type(idx)),
189-
_meta.output_seg_quant_scale(idx), _meta.output_seg_quant_zp(idx));
190-
191-
_origin_output_shapes.push_back(shape);
192-
// Create operand
193-
const auto operand_index = subg.addOperand(shape, type_info);
194-
return operand_index;
145+
for (uint32_t i = 0; i < meta.input_seg_num(); ++i)
146+
{
147+
// Shape
148+
ir::Shape shape;
149+
for (uint32_t d = 0; d < MAX_RANK; ++d)
150+
shape.append(meta.input_seg_dims(i, d));
151+
152+
// TypeInfo
153+
ir::TypeInfo type_info(toDataType(meta.input_seg_quant_type(i)), meta.input_seg_quant_scale(i),
154+
meta.input_seg_quant_zp(i));
155+
156+
_origin_input_shapes.push_back(shape);
157+
// Create operand
158+
subg.addOperand(shape, type_info);
159+
subg.addInput(ir::OperandIndex(i), "tvn_input" + std::to_string(i));
160+
}
195161
}
196162

197-
void TrixLoader::loadOperands(ir::Graph &subg)
163+
void TrixLoader::loadOutputOperands(TrixMetaReader &meta, ir::Graph &subg)
198164
{
199-
auto in_num = _meta.input_seg_num();
200-
for (uint32_t i = 0; i < in_num; ++i)
165+
for (uint32_t i = 0; i < meta.output_seg_num(); ++i)
201166
{
202-
loadOperandFromInput(i, subg);
203-
}
204-
auto out_num = _meta.output_seg_num();
205-
for (uint32_t i = 0; i < out_num; ++i)
206-
{
207-
loadOperandFromOutput(i, subg);
167+
// Shape
168+
ir::Shape shape;
169+
for (uint32_t d = 0; d < MAX_RANK; ++d)
170+
shape.append(meta.output_seg_dims(i, d));
171+
172+
// TypeInfo
173+
ir::TypeInfo type_info(toDataType(meta.output_seg_quant_type(i)),
174+
meta.output_seg_quant_scale(i), meta.output_seg_quant_zp(i));
175+
176+
_origin_output_shapes.push_back(shape);
177+
// Create operand
178+
subg.addOperand(shape, type_info);
179+
subg.addOutput(ir::OperandIndex(subg.getInputs().size() + i), "tvn_out" + std::to_string(i));
208180
}
209181
}
210182

211183
std::unique_ptr<ir::Graph> TrixLoader::loadSubgraph()
212184
{
213185
auto subg = std::make_unique<ir::Graph>();
214-
_meta.init(_model_path.c_str());
215-
216-
// Load tensors
217-
loadOperands(*subg);
218186

219-
// Set inputs
220-
for (uint32_t i = 0; i < _meta.input_seg_num(); ++i)
187+
if (_model_path.size() == 1)
221188
{
222-
subg->addInput(inputIdxToOperandIdx(i), "tvn_input" + std::to_string(i));
189+
// Single model
190+
TrixMetaReader meta;
191+
meta.init(_model_path.front().c_str());
192+
loadInputOperands(meta, *subg);
193+
loadOutputOperands(meta, *subg);
223194
}
224-
// Set outputs
225-
for (uint32_t i = 0; i < _meta.output_seg_num(); ++i)
195+
else
226196
{
227-
subg->addOutput(outputIdxToOperandIdx(i), "tvn_out" + std::to_string(i));
197+
// Multiple models
198+
TrixMetaReader head_model_meta, tail_model_meta;
199+
head_model_meta.init(_model_path.front().c_str());
200+
tail_model_meta.init(_model_path.back().c_str());
201+
loadInputOperands(head_model_meta, *subg);
202+
loadOutputOperands(tail_model_meta, *subg);
228203
}
204+
229205
// Create operations
230206
loadBulk(*subg);
231207

@@ -243,11 +219,18 @@ void TrixLoader::loadModel(std::unique_ptr<ir::Model> &model)
243219
std::unique_ptr<ir::Model> TrixLoader::loadFromFile(const std::string &file_path)
244220
{
245221
auto model = std::make_unique<ir::Model>();
246-
// model path will be used to set Bulk param
247-
_model_path = file_path;
248-
// metadata is initialized from model path since it is loadFromFile
249-
_meta.init(_model_path.c_str());
222+
223+
std::stringstream ss{file_path};
224+
std::string path;
225+
while (std::getline(ss, path, ';'))
226+
{
227+
// model path will be used to set Bulk param
228+
// , and it can be more than one for multiple models
229+
_model_path.push_back(path);
230+
}
231+
250232
loadModel(model);
233+
251234
return model;
252235
}
253236

0 commit comments

Comments
 (0)