@@ -4329,134 +4329,6 @@ def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
43294329 )
43304330
43314331
4332- @register_meta (aten .max_unpool2d )
4333- @out_wrapper ()
4334- def meta_max_unpool2d (self , indices , output_size ):
4335- utils .alert_not_deterministic ("max_unpooling2d_forward_out" )
4336-
4337- torch ._check (
4338- indices .dtype == torch .int64 ,
4339- lambda : f"elements in indices should be type int64 but got: { indices .dtype } " ,
4340- )
4341- torch ._check (
4342- len (output_size ) == 2 ,
4343- lambda : (
4344- f"There should be exactly two elements (height, width) in output_size, "
4345- f"but got { len (output_size )} elements."
4346- ),
4347- )
4348-
4349- oheight , owidth = output_size
4350-
4351- torch ._check (
4352- self .ndim in (3 , 4 ),
4353- lambda : (
4354- f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
4355- f"but got a tensor with { self .ndim } dimensions."
4356- ),
4357- )
4358- torch ._check (
4359- self .shape == indices .shape ,
4360- lambda : (
4361- f"Expected shape of indices to be same as that of the input tensor ({ self .shape } ) "
4362- f"but got indices tensor with shape: { indices .shape } "
4363- ),
4364- )
4365-
4366- for i in range (1 , self .ndim ):
4367- torch ._check (
4368- self .size (i ) > 0 ,
4369- lambda : (
4370- f"max_unpooling2d(): "
4371- f"Expected input to have non-zero size for non-batch dimensions, "
4372- f"but got { self .shape } with dimension { i } being empty."
4373- ),
4374- )
4375-
4376- self = self .contiguous ()
4377-
4378- if self .ndim == 3 :
4379- nchannels = self .size (0 )
4380- result = self .new_empty ((nchannels , oheight , owidth ))
4381- else :
4382- nbatch = self .size (0 )
4383- nchannels = self .size (1 )
4384- result = self .new_empty ((nbatch , nchannels , oheight , owidth ))
4385-
4386- return result
4387-
4388-
4389- def _max_unpooling3d_shape_check (input , indices , output_size , stride , padding , fn_name ):
4390- torch ._check (
4391- indices .dtype == torch .int64 , lambda : "elements in indices should be type int64"
4392- )
4393- torch ._check (
4394- input .ndim in (4 , 5 ),
4395- lambda : f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with { input .ndim } dimensions." ,
4396- )
4397- torch ._check (
4398- len (output_size ) == 3 ,
4399- lambda : (
4400- f"There should be exactly three elements (depth, height, width) in output_size, "
4401- f"but got { len (output_size )} elements."
4402- ),
4403- )
4404- torch ._check (
4405- len (stride ) == 3 ,
4406- lambda : f"There should be exactly three elements (depth, height, width) in stride, but got: { len (stride )} elements." ,
4407- )
4408- torch ._check (
4409- len (padding ) == 3 ,
4410- lambda : f"There should be exactly three elements (depth, height, width) in padding, but got: { len (padding )} elements." ,
4411- )
4412- torch ._check (
4413- input .shape == indices .shape ,
4414- lambda : (
4415- f"Expected shape of indices to be same as that of the input tensor ({ input .shape } ) "
4416- f"but got indices tensor with shape: { indices .shape } "
4417- ),
4418- )
4419-
4420- for i in range (1 , input .ndim ):
4421- torch ._check (
4422- input .size (i ) > 0 ,
4423- lambda : (
4424- f"{ fn_name } : "
4425- f"Expected input to have non-zero size for non-batch dimensions, "
4426- f"but got { input .shape } with dimension { i } being empty."
4427- ),
4428- )
4429-
4430- torch ._check (
4431- stride [0 ] > 0 and stride [1 ] > 0 and stride [2 ] > 0 ,
4432- lambda : f"strides should be greater than zero, but got stride: { stride } " ,
4433- )
4434-
4435-
4436- @register_meta (aten .max_unpool3d )
4437- @out_wrapper ()
4438- def meta_max_unpool3d (self , indices , output_size , stride , padding ):
4439- utils .alert_not_deterministic ("max_unpooling3d_forward_out" )
4440-
4441- _max_unpooling3d_shape_check (
4442- self , indices , output_size , stride , padding , "max_unpooling3d()"
4443- )
4444-
4445- self = self .contiguous ()
4446-
4447- odepth , oheight , owidth = output_size
4448-
4449- if self .ndim == 4 :
4450- nchannels = self .size (0 )
4451- result = self .new_empty ((nchannels , odepth , oheight , owidth ))
4452- else :
4453- nbatch = self .size (0 )
4454- nchannels = self .size (1 )
4455- result = self .new_empty ((nbatch , nchannels , odepth , oheight , owidth ))
4456-
4457- return result
4458-
4459-
44604332@register_meta (aten .max_pool3d_with_indices )
44614333@out_wrapper ("out" , "indices" )
44624334def meta_max_pool3d_with_indices (
0 commit comments