Skip to content

Commit 4117aa5

Browse files
authored
[Fix] Fix ascend faster-rcnn (#1842)
* Fix ascend faster-rcnn * remove static config
1 parent c4c1f52 commit 4117aa5

File tree

5 files changed

+6
-8
lines changed

5 files changed

+6
-8
lines changed

configs/mmdet/detection/detection_ascend_static-800x1344.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

mmdeploy/backend/ascend/wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def forward(self, inputs: Dict[str,
394394

395395
for binding in self._model_desc.outputs:
396396
self._copy_buffer_to_tensor(
397-
self._output.buffers[binding.index], tensor)
397+
self._output.buffers[binding.index], outputs[binding.name])
398398

399399
return outputs
400400

mmdeploy/codebase/mmdet/deploy/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ def __gather_topk__trt(*inputs: Sequence[torch.Tensor],
249249
return outputs
250250

251251

252+
@FUNCTION_REWRITER.register_rewriter(
253+
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
254+
backend=Backend.ASCEND.value)
252255
@FUNCTION_REWRITER.register_rewriter(
253256
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
254257
backend=Backend.COREML.value)

mmdeploy/codebase/mmdet/models/roi_heads/standard_roi_head.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def standard_roi_head__predict_bbox(self,
3838
(num_instances, ).
3939
"""
4040
rois = rpn_results_list[0]
41-
rois_dims = rois.shape[-1]
41+
rois_dims = int(rois.shape[-1])
4242
batch_index = torch.arange(
4343
rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
4444
rois.size(0), rois.size(1), 1)

mmdeploy/mmcv/ops/nms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,4 +713,4 @@ def multiclass_nms__ascend(boxes: Tensor,
713713
boxes, scores, score_threshold, iou_threshold, keep_top_k, keep_top_k)
714714

715715
dets = torch.cat([nmsed_boxes, nmsed_scores.unsqueeze(2)], dim=-1)
716-
return dets, nmsed_classes
716+
return dets, nmsed_classes.int()

0 commit comments

Comments
 (0)