@@ -200,19 +200,56 @@ def meta_adaptive_avg_pool2d_backward(
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
 @register_meta(aten._cudnn_rnn.default)
 def meta_cuda_rnn(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_stride0: int,
-    weight_buf: torch.Tensor,
-    hx: torch.Tensor,
-    cx: Optional[torch.Tensor] = None,
-    *args,
-    **kwargs,
+    input,
+    weight,
+    weight_stride0,
+    weight_buf,
+    hx,
+    cx,
+    mode,
+    hidden_size,
+    proj_size,
+    num_layers,
+    batch_first,
+    dropout,
+    train,
+    bidirectional,
+    batch_sizes,
+    dropout_state,
 ):
-    if cx is not None:
-        return torch.empty_like(input), torch.empty_like(hx), torch.empty_like(cx)
+
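+    # Shape inference mirrors the native kernel in aten/src/ATen/native/cudnn/RNN.cpp:
+    # a packed input is flagged by a non-empty batch_sizes and is laid out as
+    # [sum(batch_sizes), input_size] rather than a padded [seq, batch, *] /
+    # [batch, seq, *] tensor.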
+    is_input_packed = len(batch_sizes) != 0
+    if is_input_packed:
+        seq_length = len(batch_sizes)
+        mini_batch = batch_sizes[0]
+        batch_sizes_sum = input.shape[0]
     else:
-        return torch.empty_like(input), torch.empty_like(hx), torch.empty((), device='meta')
+        seq_length = input.shape[1] if batch_first else input.shape[0]
+        mini_batch = input.shape[0] if batch_first else input.shape[1]
+        batch_sizes_sum = -1
+
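+    # The per-step output feature size is proj_size when LSTM projections are
+    # enabled, otherwise hidden_size, and it is multiplied by the number of
+    # directions for bidirectional runs.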
+    num_directions = 2 if bidirectional else 1
+    out_size = proj_size if proj_size != 0 else hidden_size
+    if is_input_packed:
+        out_shape = [batch_sizes_sum, out_size * num_directions]
+    else:
+        out_shape = (
+            [mini_batch, seq_length, out_size * num_directions]
+            if batch_first
+            else [seq_length, mini_batch, out_size * num_directions]
+        )
+    output = input.new_empty(out_shape)
+
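+    # Hidden/cell states are per layer and direction. The cell state keeps
+    # hidden_size (projections only affect hy and the output), and cy is an
+    # empty tensor when no cell state was passed in.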
+    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
+    cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape)
+
+    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
+
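+    # The reserve buffer's real size is only known by querying cuDNN at
+    # runtime, so the meta kernel allocates an empty uint8 placeholder.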
+    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
+    reserve_shape = 0 if train else 0
+    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
+
+    return output, hy, cy, reserve, weight_buf
 
 
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp