Skip to content

Commit 9659199

Browse files
fix: add realesrgan as a dependency to fix basicsr missing problem (#569)
* fix: add realesrgan as a dependency to fix basicsr missing problem * ci: update gliner python dependency * feat: super duper new features * fix: remove feature time * fix: add protection around gliner import in torch structured --------- Co-authored-by: Gaspar Rochette <gaspar.rochette@pruna.ai>
1 parent 77b91f3 commit 9659199

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ dependencies = [
143143
"whisper-s2t==1.3.1",
144144
"hqq==0.2.7.post1",
145145
"torchao>=0.12.0,<0.16.0", # 0.16.0 breaks diffusers 0.36.0, torch+torch: https://github.com/pytorch/ao/issues/2919#issue-3375688762
146-
"gliner; python_version >= '3.10'",
146+
"gliner; python_version >= '3.11'",
147147
"piq",
148148
"opencv-python",
149149
"kernels",
@@ -153,6 +153,7 @@ dependencies = [
153153
"peft>=0.18.0",
154154
"trl<=0.21.0",
155155
"termcolor==2.3.0",
156+
"realesrgan"
156157
]
157158

158159
[project.optional-dependencies]

src/pruna/algorithms/torch_structured.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def model_check_fn(self, model: Any) -> bool:
149149
return True
150150
if isinstance(model, imported_modules["torchvision"].models.resnet.ResNet):
151151
return True
152-
if isinstance(model, imported_modules["GLiNER"]):
152+
if imported_modules["GLiNER"] is not None and isinstance(model, imported_modules["GLiNER"]):
153153
return True
154154
return isinstance(model, imported_modules["timm"].models.resnet.ResNet)
155155

@@ -257,7 +257,6 @@ def import_algorithm_packages(self) -> Dict[str, Any]:
257257
import timm
258258
import torch_pruning as tp
259259
import torchvision
260-
from gliner import GLiNER
261260
from timm.models.mvitv2 import MultiScaleAttention
262261
from timm.models.mvitv2 import MultiScaleVit as MViT
263262
from transformers.models.llama.modeling_llama import LlamaForCausalLM as Llama
@@ -266,6 +265,10 @@ def import_algorithm_packages(self) -> Dict[str, Any]:
266265
except ImportError:
267266
pruna_logger.error("TorchStructuredPruner: You need the GPU version of Pruna (timm, torchvision).")
268267
raise
268+
try:
269+
from gliner import GLiNER
270+
except ImportError: # onnxruntime doesn't have python3.10- wheels
271+
GLiNER = None # noqa: N806 # type: ignore[assignment]
269272
return dict(
270273
timm=timm,
271274
torchvision=torchvision,

0 commit comments

Comments
 (0)