 
 #include "torch/csrc/jit/frontend/function_schema_parser.h"
 #include "torch/csrc/jit/ir/ir.h"
-#include "torch/csrc/jit/ir/ir_views.h"
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/loop_unrolling.h"
 #include "torch/csrc/jit/passes/lower_graph.h"
@@ -128,193 +127,54 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
   return conversion::VerifyConverterSupportForBlock(g->block());
 }
 
-void AddSegmentedBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& g,
-    partitioning::SegmentedBlock& seg,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  // old_to_new_g contains: original global graph value => new global graph value,
-  // mini_to_new_g: mini graph value -> new graph value
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
-  size_t input_idx = 0;
-  if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
-    if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-      auto self = g->insertInput(0, "self_1");
-      self->setType(seg.inputs()[0]->type());
-    }
-    mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
-  }
-
-  for (auto& raw_input : seg.raw_inputs()) {
-    if (old_to_new_g.count(raw_input)) {
-      mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
-    }
-  }
-
-  for (const auto n : seg.nodes()) {
-    util::cloneNode(n, g, mini_to_new_g);
-  }
-
-  // original graph value => new global graph value
-  for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
-    old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
-  }
-  size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
-  for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
-    if (!old_to_new_g.count(seg.raw_inputs()[i])) {
-      old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
-    }
-  }
-
-  return;
-}
-
-typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
-    GraphAndMapping;
-
-void AddIfBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& new_g,
-    torch::jit::Node* if_node,
-    const std::vector<GraphAndMapping>& graph_and_mappings,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  torch::jit::IfView if_view(if_node);
-
-  // create a new if node in new_g and add corresponding inputs
-  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
-  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
-
-  // iterate over all blocks and add them to new created prim::If
-  for (auto graph_and_mapping : graph_and_mappings) {
-    auto new_if_block = new_if->addBlock();
-    auto cur_block_graph = graph_and_mapping.first;
-    auto cur_block_mapping = graph_and_mapping.second;
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
-    for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
-      // it's mini graph's input
-      if (old_to_new_g.count(i.first)) {
-        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
-      }
-    }
-
-    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
-    new_if_block->cloneFrom(cur_block_graph->block(), env);
-    if (cur_block_graph->inputs().size() &&
-        cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
-      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-        auto self = new_g->insertInput(0, "self_1");
-        self->setType(cur_block_graph->inputs()[0]->type());
-      }
-      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
-    }
-    for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
-      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
-      new_if_block->eraseInput(i);
-    }
-  }
-  for (auto ov : if_view.outputs()) {
-    auto no = new_if->addOutput();
-    old_to_new_g[ov] = no;
-    no->copyMetadata(ov);
-  }
-  return;
-}
-
-GraphAndMapping ConstructFallbackGraph_(
+partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::script::Module& new_mod,
     torch::jit::Block* block,
-    partitioning::PartitioningCtx* partitioning_ctx,
-    conversion::ConversionInfo convert_info,
+    CompileSpec cfg,
     ir::StaticParams static_params,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map) {
-  auto new_g = std::make_shared<torch::jit::Graph>();
+    ir::CollectionTypeMap first_use_types) {
+  auto convert_info = cfg.convert_info;
+  auto partitioning_info = cfg.partitioning_info;
+
+  auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
+  auto collection_input_ivalues_map =
+      partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
 
-  auto segmented_blocks = partitioning::Partition(partitioning_ctx, block, example_tensor_map);
+  partitioning::Partition(&partitioning_ctx, collection_input_ivalues_map);
 
-  // the mapping from lowering graph => fallback global graph
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-  for (auto input : block->inputs()) {
-    util::getOrAddInputForValue(input, new_g, old_to_new_g);
-  }
+  for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
+    partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
 
-  for (auto& seg_block : segmented_blocks) {
-    LOG_INFO("Block segment:" << seg_block);
-    std::ostringstream trt_engine_id;
-    trt_engine_id << reinterpret_cast<const int*>(&seg_block);
-
-    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-      auto shapes = seg_block.in_shapes();
-      auto types = seg_block.in_types();
-      std::vector<ir::Input> inputs;
-      for (size_t i = 0; i < shapes.size(); i++) {
-        auto in = ir::Input(shapes[i]);
-        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
-        inputs.push_back(in);
-      }
-      // update the input ranges for each segments
-      convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
-
-      // TODO mapping Inputs Ivalue to flatten one here
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
-      auto temp_g = std::make_shared<torch::jit::Graph>();
-      auto device_spec = convert_info.engine_settings.device;
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
-
-      seg_block.update_graph(temp_g);
-      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-    } else {
-      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
-        auto if_node = seg_block.raw_nodes()[0];
-
-        // convert the 2 blocks in prim::if and get the converted graph with mappings
-        std::vector<GraphAndMapping> graph_and_mappings;
-        for (auto cur_block : if_node->blocks()) {
-          graph_and_mappings.push_back(ConstructFallbackGraph_(
-              new_mod, cur_block, partitioning_ctx, convert_info, static_params, example_tensor_map));
+    for (auto& seg_block : segmented_blocks) {
+      LOG_INFO("Block segment:" << seg_block);
+      std::ostringstream trt_engine_id;
+      trt_engine_id << reinterpret_cast<const int*>(&seg_block);
+
+      if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        auto shapes = seg_block.in_shapes();
+        auto types = seg_block.in_types();
+        std::vector<ir::Input> inputs;
+        for (size_t i = 0; i < shapes.size(); i++) {
+          auto in = ir::Input(shapes[i]);
+          in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+          inputs.push_back(in);
         }
-        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
+        // update the input ranges for each segments
+        convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
 
-      } else {
-        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-      }
-    }
-  }
+        // TODO mapping Inputs Ivalue to flatten one here
+        auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
+        auto temp_g = std::make_shared<torch::jit::Graph>();
+        auto device_spec = convert_info.engine_settings.device;
+        auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+        AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
 
-  if (block->outputs().size() > 1) {
-    std::vector<torch::jit::Value*> fallback_graph_vector;
-    for (auto& output : block->outputs()) {
-      if (old_to_new_g.count(output)) {
-        fallback_graph_vector.push_back(old_to_new_g[output]);
+        seg_block.update_graph(temp_g);
       }
     }
-    torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
-    auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
-    new_g->block()->appendNode(return_tuple_node);
-    // Set the output as the produced tuple
-    new_g->registerOutput(return_tuple_node->outputs()[0]);
-  } else {
-    if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
-      new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
-    }
   }
 
-  return {new_g, old_to_new_g};
-}
-
-GraphAndMapping ConstructFallbackGraph(
-    torch::jit::script::Module& new_mod,
-    torch::jit::Block* block,
-    CompileSpec cfg,
-    ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
-  auto convert_info = cfg.convert_info;
-  auto partitioning_info = cfg.partitioning_info;
-
-  auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
-  auto collection_input_ivalues_map =
-      partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
-
-  return ConstructFallbackGraph_(
-      new_mod, block, &partitioning_ctx, convert_info, static_params, collection_input_ivalues_map);
+  return partitioning::Stitch(&partitioning_ctx, block);
 }
 
 void MapInputsAndDetermineDTypes(
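Note on the return type: the removed file-local GraphAndMapping typedef is superseded by partitioning::GraphAndMapping, which BuildHybridGraph now returns from partitioning::Stitch. The sketch below is illustrative only: it assumes the new alias keeps the same pair layout as the removed typedef (stitched graph plus old-value to new-value map), which matches how the updated call site in CompileGraph reads graph_and_mapping.first in the next hunk. The TakeStitchedGraph helper is hypothetical and not part of this commit.

    #include <memory>
    #include <unordered_map>
    #include <utility>

    #include "torch/csrc/jit/ir/ir.h"

    // Assumed shape of partitioning::GraphAndMapping, mirroring the typedef this commit removes:
    //   first  -> the stitched hybrid graph (TensorRT engine nodes plus Torch-executed segments)
    //   second -> map from values of the original lowered graph to values in the stitched graph
    using GraphAndMapping = std::pair<
        std::shared_ptr<torch::jit::Graph>,
        std::unordered_map<torch::jit::Value*, torch::jit::Value*>>;

    // Hypothetical consumer, following the updated CompileGraph call site in the next hunk:
    std::shared_ptr<torch::jit::Graph> TakeStitchedGraph(const GraphAndMapping& graph_and_mapping) {
      auto new_g = graph_and_mapping.first; // stitched graph used as the module's new forward graph
      // graph_and_mapping.second can translate outputs of the original block into the new graph.
      return new_g;
    }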
@@ -451,7 +311,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
        outputIsCollection)) {
-    auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), cfg, static_params, first_use_types);
+    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
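For completeness, the hybrid path guarded above only runs when partitioning is enabled and the block is not fully convertible (or the output is a collection). A hypothetical caller-side sketch of triggering that path through the public C++ front end follows; the torch_tensorrt::ts::CompileSpec fields used here (require_full_compilation, torch_executed_ops) and the model path are assumptions about the public API of this release line, not something defined in this diff.

    #include <torch/script.h>
    #include "torch_tensorrt/torch_tensorrt.h"

    int main() {
      // Hypothetical TorchScript module; any scripted model with an op forced back to Torch works.
      auto mod = torch::jit::load("model.ts");

      auto spec = torch_tensorrt::ts::CompileSpec({torch_tensorrt::Input({1, 3, 224, 224})});
      // Keeping full compilation off and forcing one op to run in Torch yields a mixed
      // TensorRT/Torch module, i.e. the graph that BuildHybridGraph stitches together (assumed fields).
      spec.require_full_compilation = false;
      spec.torch_executed_ops.push_back("aten::flatten");

      auto hybrid_mod = torch_tensorrt::ts::compile(mod, spec);
      return 0;
    }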