3535import torch .nn as nn
3636import torch .nn .functional as F
3737
38- from torch . cuda import amp
38+ from torch import amp
3939
4040from compressai .entropy_models import EntropyBottleneck , GaussianConditional
4141from compressai .layers import QReLU
@@ -350,7 +350,6 @@ def gaussian_volume(x, sigma: float, num_levels: int):
350350 volume .append (interp .unsqueeze (2 ))
351351 return torch .cat (volume , dim = 2 )
352352
353- @amp .autocast (enabled = False )
354353 def warp_volume (self , volume , flow , scale_field , padding_mode : str = "border" ):
355354 """3D volume warping."""
356355 if volume .ndimension () != 5 :
@@ -360,14 +359,18 @@ def warp_volume(self, volume, flow, scale_field, padding_mode: str = "border"):
360359
361360 N , C , _ , H , W = volume .size ()
362361
363- grid = meshgrid2d (N , C , H , W , volume .device )
364- update_grid = grid + flow .permute (0 , 2 , 3 , 1 ).float ()
365- update_scale = scale_field .permute (0 , 2 , 3 , 1 ).float ()
366- volume_grid = torch .cat ((update_grid , update_scale ), dim = - 1 ).unsqueeze (1 )
367-
368- out = F .grid_sample (
369- volume .float (), volume_grid , padding_mode = padding_mode , align_corners = False
370- )
362+ with amp .autocast (device_type = volume .device .type , enabled = False ):
363+ grid = meshgrid2d (N , C , H , W , volume .device )
364+ update_grid = grid + flow .permute (0 , 2 , 3 , 1 ).float ()
365+ update_scale = scale_field .permute (0 , 2 , 3 , 1 ).float ()
366+ volume_grid = torch .cat ((update_grid , update_scale ), dim = - 1 ).unsqueeze (1 )
367+
368+ out = F .grid_sample (
369+ volume .float (),
370+ volume_grid ,
371+ padding_mode = padding_mode ,
372+ align_corners = False ,
373+ )
371374 return out .squeeze (2 )
372375
373376 def forward_prediction (self , x_ref , motion_info ):
0 commit comments