
Commit ff22707

refactor: extract stitching phase out of compiler.cpp
Signed-off-by: Bo Wang <[email protected]>
1 parent a165811 commit ff22707

File tree

8 files changed, +274 -240 lines changed


core/compiler.cpp

Lines changed: 36 additions & 176 deletions
@@ -11,7 +11,6 @@

 #include "torch/csrc/jit/frontend/function_schema_parser.h"
 #include "torch/csrc/jit/ir/ir.h"
-#include "torch/csrc/jit/ir/ir_views.h"
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/loop_unrolling.h"
 #include "torch/csrc/jit/passes/lower_graph.h"
@@ -128,193 +127,54 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
   return conversion::VerifyConverterSupportForBlock(g->block());
 }

-void AddSegmentedBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& g,
-    partitioning::SegmentedBlock& seg,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  // old_to_new_g contains: original global graph value => new global graph value,
-  // mini_to_new_g: mini graph value -> new graph value
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
-  size_t input_idx = 0;
-  if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
-    if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-      auto self = g->insertInput(0, "self_1");
-      self->setType(seg.inputs()[0]->type());
-    }
-    mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
-  }
-
-  for (auto& raw_input : seg.raw_inputs()) {
-    if (old_to_new_g.count(raw_input)) {
-      mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
-    }
-  }
-
-  for (const auto n : seg.nodes()) {
-    util::cloneNode(n, g, mini_to_new_g);
-  }
-
-  // original graph value => new global graph value
-  for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
-    old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
-  }
-  size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
-  for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
-    if (!old_to_new_g.count(seg.raw_inputs()[i])) {
-      old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
-    }
-  }
-
-  return;
-}
-
-typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
-    GraphAndMapping;
-
-void AddIfBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& new_g,
-    torch::jit::Node* if_node,
-    const std::vector<GraphAndMapping>& graph_and_mappings,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  torch::jit::IfView if_view(if_node);
-
-  // create a new if node in new_g and add corresponding inputs
-  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
-  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
-
-  // iterate over all blocks and add them to new created prim::If
-  for (auto graph_and_mapping : graph_and_mappings) {
-    auto new_if_block = new_if->addBlock();
-    auto cur_block_graph = graph_and_mapping.first;
-    auto cur_block_mapping = graph_and_mapping.second;
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
-    for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
-      // it's mini graph's input
-      if (old_to_new_g.count(i.first)) {
-        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
-      }
-    }
-
-    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
-    new_if_block->cloneFrom(cur_block_graph->block(), env);
-    if (cur_block_graph->inputs().size() &&
-        cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
-      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-        auto self = new_g->insertInput(0, "self_1");
-        self->setType(cur_block_graph->inputs()[0]->type());
-      }
-      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
-    }
-    for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
-      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
-      new_if_block->eraseInput(i);
-    }
-  }
-  for (auto ov : if_view.outputs()) {
-    auto no = new_if->addOutput();
-    old_to_new_g[ov] = no;
-    no->copyMetadata(ov);
-  }
-  return;
-}
-
-GraphAndMapping ConstructFallbackGraph_(
+partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::script::Module& new_mod,
     torch::jit::Block* block,
-    partitioning::PartitioningCtx* partitioning_ctx,
-    conversion::ConversionInfo convert_info,
+    CompileSpec cfg,
     ir::StaticParams static_params,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map) {
-  auto new_g = std::make_shared<torch::jit::Graph>();
+    ir::CollectionTypeMap first_use_types) {
+  auto convert_info = cfg.convert_info;
+  auto partitioning_info = cfg.partitioning_info;
+
+  auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
+  auto collection_input_ivalues_map =
+      partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);

-  auto segmented_blocks = partitioning::Partition(partitioning_ctx, block, example_tensor_map);
+  partitioning::Partition(&partitioning_ctx, collection_input_ivalues_map);

-  // the mapping from lowering graph => fallback global graph
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-  for (auto input : block->inputs()) {
-    util::getOrAddInputForValue(input, new_g, old_to_new_g);
-  }
+  for (auto &partitioned_block : partitioning_ctx.partitioned_blocks) {
+    partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;

-  for (auto& seg_block : segmented_blocks) {
-    LOG_INFO("Block segment:" << seg_block);
-    std::ostringstream trt_engine_id;
-    trt_engine_id << reinterpret_cast<const int*>(&seg_block);
-
-    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-      auto shapes = seg_block.in_shapes();
-      auto types = seg_block.in_types();
-      std::vector<ir::Input> inputs;
-      for (size_t i = 0; i < shapes.size(); i++) {
-        auto in = ir::Input(shapes[i]);
-        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
-        inputs.push_back(in);
-      }
-      // update the input ranges for each segments
-      convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
-
-      // TODO mapping Inputs Ivalue to flatten one here
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
-      auto temp_g = std::make_shared<torch::jit::Graph>();
-      auto device_spec = convert_info.engine_settings.device;
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
-
-      seg_block.update_graph(temp_g);
-      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-    } else {
-      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
-        auto if_node = seg_block.raw_nodes()[0];
-
-        // convert the 2 blocks in prim::if and get the converted graph with mappings
-        std::vector<GraphAndMapping> graph_and_mappings;
-        for (auto cur_block : if_node->blocks()) {
-          graph_and_mappings.push_back(ConstructFallbackGraph_(
-              new_mod, cur_block, partitioning_ctx, convert_info, static_params, example_tensor_map));
+    for (auto& seg_block : segmented_blocks) {
+      LOG_INFO("Block segment:" << seg_block);
+      std::ostringstream trt_engine_id;
+      trt_engine_id << reinterpret_cast<const int*>(&seg_block);
+
+      if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        auto shapes = seg_block.in_shapes();
+        auto types = seg_block.in_types();
+        std::vector<ir::Input> inputs;
+        for (size_t i = 0; i < shapes.size(); i++) {
+          auto in = ir::Input(shapes[i]);
+          in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+          inputs.push_back(in);
         }
-        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
+        // update the input ranges for each segments
+        convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

-      } else {
-        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-      }
-    }
-  }
+        // TODO mapping Inputs Ivalue to flatten one here
+        auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
+        auto temp_g = std::make_shared<torch::jit::Graph>();
+        auto device_spec = convert_info.engine_settings.device;
+        auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+        AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

-  if (block->outputs().size() > 1) {
-    std::vector<torch::jit::Value*> fallback_graph_vector;
-    for (auto& output : block->outputs()) {
-      if (old_to_new_g.count(output)) {
-        fallback_graph_vector.push_back(old_to_new_g[output]);
+        seg_block.update_graph(temp_g);
       }
     }
-    torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
-    auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
-    new_g->block()->appendNode(return_tuple_node);
-    // Set the output as the produced tuple
-    new_g->registerOutput(return_tuple_node->outputs()[0]);
-  } else {
-    if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
-      new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
-    }
   }
-  return {new_g, old_to_new_g};
-}
-
-GraphAndMapping ConstructFallbackGraph(
-    torch::jit::script::Module& new_mod,
-    torch::jit::Block* block,
-    CompileSpec cfg,
-    ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
-  auto convert_info = cfg.convert_info;
-  auto partitioning_info = cfg.partitioning_info;
-
-  auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
-  auto collection_input_ivalues_map =
-      partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);

-  return ConstructFallbackGraph_(
-      new_mod, block, &partitioning_ctx, convert_info, static_params, collection_input_ivalues_map);
+  return partitioning::Stitch(&partitioning_ctx, block);
 }

 void MapInputsAndDetermineDTypes(
@@ -451,7 +311,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
        outputIsCollection)) {
-    auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), cfg, static_params, first_use_types);
+    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {

core/partitioning/BUILD

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ cc_library(
     srcs = [
         "partitioning.cpp",
         "shape_analysis.cpp",
+        "stitching.cpp"
     ],
     hdrs = [
        "partitioning.h",
