Skip to content

Commit 597d8e4

Browse files
authored
Add aten::record_stream (#1047)
- [x] record_stream
1 parent fda86cf commit 597d8e4

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
3+
#include <c10/xpu/XPUCachingAllocator.h>
4+
5+
#ifndef AT_PER_OPERATOR_HEADERS
6+
#include <ATen/NativeFunctions.h>
7+
#else
8+
#include <ATen/ops/record_stream_native.h>
9+
#endif
10+
11+
namespace at::native {
12+
void record_stream_xpu(Tensor& self, c10::Stream stream) {
13+
struct c10::StreamData3 data = stream.pack3();
14+
c10::xpu::XPUCachingAllocator::recordStream(self.storage().data_ptr(), at::xpu::XPUStream::unpack3(data.stream_id, data.device_index, data.device_type));
15+
}
16+
} // namespace at::native

yaml/native/native_functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7164,6 +7164,11 @@
71647164
- func: index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor
71657165
variants: function, method
71667166

7167+
- func: record_stream(Tensor(a!) self, Stream s) -> ()
7168+
variants: method
7169+
dispatch:
7170+
XPU: record_stream_xpu
7171+
71677172
- func: i0(Tensor self) -> Tensor
71687173
structured_delegate: i0.out
71697174
variants: function, method

0 commit comments

Comments
 (0)