Commit 45cccb1

Merge pull request #1 from mideind/separate-torch-dep
Make torch an optional dependency
2 parents: 88c1e70 + 58a03a6

5 files changed: 35 additions and 11 deletions

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ jobs:
       - name: Run tests
         run: |
           # --locked ensures we use exact versions from uv.lock without updating
-          uv run --locked pytest tests/ -v
+          uv run --locked --extra torch pytest tests/ -v
         env:
           HF_HUB_CACHE: ~/.cache/huggingface
           HF_TOKEN: ${{ secrets.HF_TOKEN }}

README.md

Lines changed: 6 additions & 1 deletion
@@ -10,10 +10,15 @@ A high-level Python interface for PoS tagging Icelandic text using the [IceBERT-

 ```bash
 # This package is currently not available on PyPI, so you need to install it directly from the source repository.
+
+# Without PyTorch (lighter, but model inference won't work)
 pip install git+ssh://git@github.com/mideind/IceBERT-PoS.git
+
+# With PyTorch support (required for model inference) - RECOMMENDED
+pip install "git+ssh://git@github.com/mideind/IceBERT-PoS.git[torch]"
 ```

-This will install the package with PyTorch.
+> **Note**: The `[torch]` extra is required for model inference, as PyTorch models need PyTorch to run. The default installation is only useful for development work that doesn't involve running the actual models.

 ## Features


pyproject.toml

Lines changed: 6 additions & 6 deletions
@@ -1,23 +1,23 @@
 [project]
 name = "icebert-pos"
-dynamic = ["version"] # managed by setuptools-scm
+dynamic = ["version"] # managed by setuptools-scm
 description = "A package for interacting with the IceBERT PoS model(s)."
 readme = "README.md"
 requires-python = ">=3.10,<4.0"
 dependencies = [
     "tokenizer>=3.4.4,<4.0",
-    "transformers[torch]>=4.46.3,<5.0",
+    "transformers>=4.46.3,<5.0",
     "rich>=13.0.0,<14.0",
 ]

+[project.optional-dependencies]
+torch = ["transformers[torch]>=4.46.3,<5.0"]
+
 [project.scripts]
 icebert-pos = "icebert_pos.cli:main"

 [dependency-groups]
-dev = [
-    "pytest",
-    "ruff",
-]
+dev = ["pytest", "ruff"]

 [tool.setuptools_scm]
 # Use git tags for versioning
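
Each key under `[project.optional-dependencies]` is published in the built package's metadata as a `Provides-Extra` field, so the new `torch` extra can be inspected from Python once the package is installed. The following is a minimal sketch using only the standard library; the helper name `declared_extras` is hypothetical and not part of this repository.

```python
from importlib.metadata import PackageNotFoundError, metadata


def declared_extras(dist_name: str) -> list[str]:
    """Return the extras an installed distribution declares, or [] if it is not installed.

    Hypothetical helper for illustration; not part of icebert-pos.
    """
    try:
        meta = metadata(dist_name)
    except PackageNotFoundError:
        return []
    # Each [project.optional-dependencies] key appears as one Provides-Extra field.
    return sorted(meta.get_all("Provides-Extra") or [])


if __name__ == "__main__":
    # "icebert-pos" is the project name from pyproject.toml.
    # This prints an empty list when the package is not installed.
    print(declared_extras("icebert-pos"))
```

Note that this only reports which extras a distribution declares, not whether an extra was actually selected at install time.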

src/icebert_pos/interface.py

Lines changed: 14 additions & 2 deletions
@@ -1,12 +1,16 @@
 # Copyright (C) Miðeind ehf.
 # Simple POS tagging interface with classical tokenization

+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass
+from typing import TYPE_CHECKING

 import tokenizer
-import torch
-from torch.nn.utils.rnn import pad_sequence
+
+if TYPE_CHECKING:
+    import torch

 logger = logging.getLogger(__name__)

@@ -154,6 +158,14 @@ def batch_sentences(
     Returns:
         Batched input tensors
     """
+    try:
+        from torch.nn.utils.rnn import pad_sequence
+    except ModuleNotFoundError as e:
+        raise ImportError(
+            "The 'torch' library is required for this function. Please install it using "
+            "'pip install icebert-pos[torch]'."
+        ) from e
+
     # Unzip the list of tuples into separate lists
     input_ids, attention_mask, word_mask = zip(*sentence_tensors, strict=True)

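
The change to `interface.py` combines two techniques: importing `torch` only under `TYPE_CHECKING` so annotations still type-check, and deferring the real import to the function that needs it. The sketch below isolates that pattern so it can be run with or without PyTorch installed. The function name `pad_input_ids` and the error wording are assumptions for illustration, and the sketch uses quoted annotations instead of `from __future__ import annotations` so it can be pasted anywhere in a file.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime,
    # so importing this module does not require PyTorch.
    import torch


def pad_input_ids(input_ids: "list[torch.Tensor]") -> "torch.Tensor":
    """Pad variable-length 1-D tensors into one batch (hypothetical helper)."""
    try:
        # Deferred import: only fails when the function is actually called.
        from torch.nn.utils.rnn import pad_sequence
    except ModuleNotFoundError as e:
        raise ImportError(
            "The 'torch' library is required for this function. "
            "Install the package with its 'torch' extra."
        ) from e

    # batch_first=True yields shape (batch, max_len); shorter rows are padded with 0.
    return pad_sequence(input_ids, batch_first=True, padding_value=0)
```

With this layout the module imports cleanly in a torch-free environment, and the missing dependency surfaces as an `ImportError` with an install hint at the first call that needs it.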

uv.lock

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default.
