Skip to content

Commit d997a40

Browse files
committed
chores: update tested pytorch images and bump dependencies
(cherry picked from commit 3c8b7bc01f1401c8e0a08934a317e57d7299365b)
1 parent 94853d2 commit d997a40

File tree

6 files changed

+22
-17
lines changed

6 files changed

+22
-17
lines changed

.gitlab-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ test:
9696
parallel:
9797
matrix:
9898
- PYTORCH_IMAGE:
99-
- "1.8.1-cuda11.1-cudnn8-devel"
99+
- "1.13.1-cuda11.6-cudnn8-devel"
100100
- "2.2.0-cuda11.8-cudnn8-devel"
101101
tags:
102102
- docker

compressai/models/video/google.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import torch.nn as nn
3636
import torch.nn.functional as F
3737

38-
from torch.cuda import amp
38+
from torch import amp
3939

4040
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
4141
from compressai.layers import QReLU
@@ -350,7 +350,6 @@ def gaussian_volume(x, sigma: float, num_levels: int):
350350
volume.append(interp.unsqueeze(2))
351351
return torch.cat(volume, dim=2)
352352

353-
@amp.autocast(enabled=False)
354353
def warp_volume(self, volume, flow, scale_field, padding_mode: str = "border"):
355354
"""3D volume warping."""
356355
if volume.ndimension() != 5:
@@ -360,14 +359,18 @@ def warp_volume(self, volume, flow, scale_field, padding_mode: str = "border"):
360359

361360
N, C, _, H, W = volume.size()
362361

363-
grid = meshgrid2d(N, C, H, W, volume.device)
364-
update_grid = grid + flow.permute(0, 2, 3, 1).float()
365-
update_scale = scale_field.permute(0, 2, 3, 1).float()
366-
volume_grid = torch.cat((update_grid, update_scale), dim=-1).unsqueeze(1)
367-
368-
out = F.grid_sample(
369-
volume.float(), volume_grid, padding_mode=padding_mode, align_corners=False
370-
)
362+
with amp.autocast(device_type=volume.device.type, enabled=False):
363+
grid = meshgrid2d(N, C, H, W, volume.device)
364+
update_grid = grid + flow.permute(0, 2, 3, 1).float()
365+
update_scale = scale_field.permute(0, 2, 3, 1).float()
366+
volume_grid = torch.cat((update_grid, update_scale), dim=-1).unsqueeze(1)
367+
368+
out = F.grid_sample(
369+
volume.float(),
370+
volume_grid,
371+
padding_mode=padding_mode,
372+
align_corners=False,
373+
)
371374
return out.squeeze(2)
372375

373376
def forward_prediction(self, x_ref, motion_info):

compressai/utils/video/eval_model/__main__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@
4343
import torch.nn.functional as F
4444

4545
from pytorch_msssim import ms_ssim
46-
from torch import Tensor
47-
from torch.cuda import amp
46+
from torch import Tensor, amp
4847
from torch.utils.model_zoo import tqdm
4948

5049
import compressai
@@ -370,7 +369,9 @@ def run_inference(
370369
if sequence_metrics_path.is_file():
371370
continue
372371

373-
with amp.autocast(enabled=args["half"]):
372+
with amp.autocast(
373+
device_type=next(net.parameters()).device.type, enabled=args["half"]
374+
):
374375
with torch.no_grad():
375376
if entropy_estimation:
376377
metrics = eval_model_entropy_estimation(net, filepath)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[build-system]
2-
requires = ["setuptools>=42", "wheel", "pybind11>=2.6.0", "torch"]
2+
requires = ["setuptools>=42", "wheel", "pybind11>=2.12", "torch"]
33
build-backend = "setuptools.build_meta"
44

55
[tool.black]

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,11 @@ def get_extra_requirements():
136136
python_requires=">=3.8",
137137
install_requires=[
138138
"einops",
139-
"numpy>=1.21.0",
139+
"numpy>=1.21.0, <2",
140140
"pandas",
141141
"scipy",
142142
"matplotlib",
143-
"torch>=1.7.1, <2.3",
143+
"torch>=1.13.1",
144144
"torch-geometric>=2.3.0",
145145
"typing-extensions>=4.0.0",
146146
"torchvision",

tests/test_entropy_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def test_compiling(self):
273273

274274
def test_update(self):
275275
# get a pretrained model
276+
276277
net = bmshj2018_factorized(quality=1, pretrained=True).eval()
277278
assert not net.update()
278279
assert not net.update(force=False)

0 commit comments

Comments
 (0)