Skip to content

Commit 12e7bcd

Browse files
authored
register meta func for rnn (#2159)
1 parent cfe2a9b commit 12e7bcd

File tree

1 file changed

+48
-11
lines changed

1 file changed

+48
-11
lines changed

colossalai/fx/_meta_registrations.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -200,19 +200,56 @@ def meta_adaptive_avg_pool2d_backward(
200200
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn.default)
def meta_cuda_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):
    """Meta (shape-only) implementation of ``aten._cudnn_rnn``.

    Mirrors the output-shape logic of the cuDNN RNN kernel (see the
    RNN.cpp link above) so tracing with 'meta' tensors can propagate
    shapes without executing the real kernel.

    Returns:
        Tuple ``(output, hy, cy, reserve, weight_buf)`` of empty tensors
        shaped as the real cuDNN kernel would produce them. ``cy`` is an
        empty 0-element tensor when ``cx`` is None (non-LSTM modes), and
        the reserve buffer is always reported empty (see TODO below).
    """
    # A non-empty batch_sizes means the input is packed (PackedSequence
    # layout): shape (sum_of_batch_sizes, input_size) instead of the
    # padded (seq, batch, input_size) / (batch, seq, input_size) layout.
    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        seq_length = len(batch_sizes)
        mini_batch = batch_sizes[0]
        batch_sizes_sum = input.shape[0]
    else:
        seq_length = input.shape[1] if batch_first else input.shape[0]
        mini_batch = input.shape[0] if batch_first else input.shape[1]
        batch_sizes_sum = -1  # unused in the padded layout

    num_directions = 2 if bidirectional else 1
    # With LSTM projections (proj_size > 0) each step emits proj_size
    # features; otherwise hidden_size.
    out_size = proj_size if proj_size != 0 else hidden_size
    if is_input_packed:
        out_shape = [batch_sizes_sum, out_size * num_directions]
    else:
        out_shape = (
            [mini_batch, seq_length, out_size * num_directions]
            if batch_first
            else [seq_length, mini_batch, out_size * num_directions]
        )
    output = input.new_empty(out_shape)

    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
    # cy only exists for LSTMs; other RNN modes pass cx=None.
    cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape)

    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

    # TODO: query cudnnGetRNNTrainingReserveSize (needs exposing to Python)
    # for the true size in training mode; until then the reserve buffer is
    # reported as empty regardless of `train`. (The original wrote the
    # degenerate `0 if train else 0` to flag this placeholder.)
    reserve = input.new_empty(0, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf
216253

217254

218255
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp

0 commit comments

Comments
 (0)