Skip to content

Commit f4342a3

Browse files
committed
New time-fused LSTM API. Performance improvements.
This change adds a new LSTM API that fuses operations across RNN time steps. It performs significantly faster than the existing iterative API. The existing iterative LSTM API also received performance improvements. BREAKING CHANGES: Previously, callers were expected to transpose `h` before passing it to `BackwardPass`. Callers must not transpose `h` anymore. The `dv` parameter in `BackwardPass` has been removed and `v` must now be mutable.
1 parent 0a0222e commit f4342a3

File tree

7 files changed

+611
-228
lines changed

7 files changed

+611
-228
lines changed

examples/device_ptr.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ template<typename T>
2323
struct device_ptr {
2424
static constexpr size_t ElemSize = sizeof(typename T::Scalar);
2525

26+
static device_ptr<T> NewByteSized(size_t bytes) {
27+
return device_ptr<T>((bytes + ElemSize - 1) / ElemSize);
28+
}
29+
2630
explicit device_ptr(size_t size_)
2731
: data(nullptr), size(size_) {
2832
void* tmp;
@@ -38,6 +42,24 @@ struct device_ptr {
3842
ToDevice(elem);
3943
}
4044

45+
device_ptr(device_ptr<T>&& other) : data(other.data), size(other.size) {
46+
other.data = nullptr;
47+
other.size = 0;
48+
}
49+
50+
device_ptr& operator=(const device_ptr<T>&& other) {
51+
if (&other != this) {
52+
data = other.data;
53+
size = other.size;
54+
other.data = nullptr;
55+
other.size = 0;
56+
}
57+
return *this;
58+
}
59+
60+
device_ptr(const device_ptr<T>& other) = delete;
61+
device_ptr& operator=(const device_ptr<T>& other) = delete;
62+
4163
void ToDevice(const T& src) {
4264
assert(size == src.size());
4365
cudaMemcpy(data, src.data(), src.size() * ElemSize, cudaMemcpyHostToDevice);
@@ -48,6 +70,10 @@ struct device_ptr {
4870
cudaMemcpy(target.data(), data, target.size() * ElemSize, cudaMemcpyDeviceToHost);
4971
}
5072

73+
size_t Size() const {
74+
return size;
75+
}
76+
5177
void zero() {
5278
cudaMemset(data, 0, size * ElemSize);
5379
}
@@ -57,5 +83,5 @@ struct device_ptr {
5783
}
5884

5985
typename T::Scalar* data;
60-
const size_t size;
86+
size_t size;
6187
};

0 commit comments

Comments
 (0)