Skip to content

Commit d14208f

Browse files
authored
feat: add ifrt copy api (#1333)
1 parent 2a102fe commit d14208f

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2628,3 +2628,15 @@ extern "C" void addSdyPropagationPipeline(
26282628
0};
26292629
mlir::sdy::addPropagationPipeline(pm, options);
26302630
}
2631+
2632+
extern "C" HeldIfrtArray *ifrt_copy_array(HeldIfrtArray *array) {
2633+
auto pjrtArray = dyn_cast<ifrt::PjRtArray>(array->obj().get());
2634+
if (pjrtArray) {
2635+
std::optional<ifrt::DeviceListRef> devices;
2636+
std::optional<ifrt::MemoryKind> memory_kind;
2637+
auto res = MyValueOrThrow(pjrtArray->Copy(
2638+
devices, memory_kind, static_cast<ifrt::ArrayCopySemantics>(0)));
2639+
return reactant::capture(res);
2640+
}
2641+
ReactantThrowError("Only ifrt-pjrt arrays are supported for now");
2642+
}

0 commit comments

Comments
 (0)