Skip to content

Commit 77b91f3

Browse files
authored
fix: remove deprecated pynvml, remove torchmetrics restrictions (#566)
* fix: remove deprecated pynvml, remove torchmetrics restrictions
1 parent ab078c5 commit 77b91f3

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ classifiers = [
113113
dependencies = [
114114
"torch>=2.7.0",
115115
"torchvision>=0.22.0",
116-
"torchmetrics[image]==1.7.4",
116+
"torchmetrics[image]",
117117
"requests>=2.31.0",
118118
"transformers<5.0.0",
119119
"pytorch-lightning",
@@ -134,7 +134,7 @@ dependencies = [
134134
"opentelemetry-sdk>=1.30.0",
135135
"opentelemetry-exporter-otlp>=1.29.0",
136136
"codecarbon",
137-
"pynvml",
137+
"nvidia-ml-py",
138138
"thop",
139139
"timm",
140140
"bitsandbytes; sys_platform != 'darwin' or platform_machine != 'arm64'",

src/pruna/evaluation/metrics/metric_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ class TorchMetrics(Enum):
179179
clip_score = (partial(CLIPScore), None, "y_x")
180180
precision = (partial(Precision), None, "y_gt")
181181
recall = (partial(Recall), None, "y_gt")
182-
psnr = (partial(PeakSignalNoiseRatio), None, "pairwise_y_gt")
182+
psnr = (partial(PeakSignalNoiseRatio, data_range=255.0), None, "pairwise_y_gt")
183183
ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt")
184184
msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt")
185185
lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt")

0 commit comments

Comments
 (0)