Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ca3d56f

Browse files
piiswrongcjolivier01
authored andcommitted
fix place device (#8450)
1 parent c830555 commit ca3d56f

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/imperative/imperative_utils.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,8 +571,15 @@ inline std::vector<Context> PlaceDevice(const nnvm::IndexedGraph& idx) {
571571
auto fwd_nid = idx.node_id(idx[i].source->control_deps[0].get());
572572
CHECK_EQ(idx[fwd_nid].source->op(), _copyto);
573573
vctx[i] = vctx[idx[fwd_nid].inputs[0].node_id];
574-
} else if (idx[i].inputs.size() && vctx[i].dev_type == -1) {
575-
vctx[i] = vctx[idx[i].inputs[0].node_id];
574+
} else if (idx[i].control_deps.size() &&
575+
vctx[idx[i].control_deps[0]].dev_type != -1) {
576+
vctx[i] = vctx[idx[i].control_deps[0]];
577+
} else {
578+
for (const auto& in : idx[i].inputs) {
579+
if (vctx[in.node_id].dev_type == -1) continue;
580+
vctx[i] = vctx[in.node_id];
581+
break;
582+
}
576583
}
577584
}
578585
// backward pass

0 commit comments

Comments
 (0)