Skip to content

Commit dac7793

Browse files
authored
Move model to device using from_pretrained (#423)
Using the `from_pretrained` function to move the model to device, trying to avoid the `meta` device bug. ``` self.model = TableTransformerForObjectDetection.from_pretrained( model, device_map=self.device ) ```
1 parent 7c05dec commit dac7793

File tree

6 files changed

+23
-5
lines changed

6 files changed

+23
-5
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 1.0.1
2+
3+
* fix: moving the table transformer model to device when loading the model instead of once the model is loaded.
4+
15
## 1.0.0
26

37
* feat: support for Python 3.10+; drop support for Python 3.9

requirements/base.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ torch
1010
timm
1111
# NOTE(alan): Pinned because this is when the most recent module we import appeared
1212
transformers>=4.25.1
13+
accelerate
1314
rapidfuzz
1415
pandas
1516
scipy

requirements/base.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#
55
# pip-compile requirements/base.in
66
#
7+
accelerate==1.7.0
8+
# via -r requirements/base.in
79
certifi==2025.4.26
810
# via requests
911
cffi==1.17.1
@@ -36,6 +38,7 @@ fsspec==2025.3.2
3638
huggingface-hub==0.31.2
3739
# via
3840
# -r requirements/base.in
41+
# accelerate
3942
# timm
4043
# tokenizers
4144
# transformers
@@ -58,6 +61,7 @@ networkx==3.4.2
5861
numpy==2.2.5
5962
# via
6063
# -r requirements/base.in
64+
# accelerate
6165
# contourpy
6266
# matplotlib
6367
# onnx
@@ -75,6 +79,7 @@ opencv-python==4.11.0.86
7579
# via -r requirements/base.in
7680
packaging==25.0
7781
# via
82+
# accelerate
7883
# huggingface-hub
7984
# matplotlib
8085
# onnxruntime
@@ -91,6 +96,8 @@ protobuf==6.31.0
9196
# via
9297
# onnx
9398
# onnxruntime
99+
psutil==7.0.0
100+
# via accelerate
94101
pycparser==2.22
95102
# via cffi
96103
pyparsing==3.2.3
@@ -107,6 +114,7 @@ pytz==2025.2
107114
# via pandas
108115
pyyaml==6.0.2
109116
# via
117+
# accelerate
110118
# huggingface-hub
111119
# timm
112120
# transformers
@@ -120,6 +128,7 @@ requests==2.32.3
120128
# transformers
121129
safetensors==0.5.3
122130
# via
131+
# accelerate
123132
# timm
124133
# transformers
125134
scipy==1.15.3
@@ -137,6 +146,7 @@ tokenizers==0.21.1
137146
torch==2.7.0
138147
# via
139148
# -r requirements/base.in
149+
# accelerate
140150
# timm
141151
# torchvision
142152
torchvision==0.22.0

requirements/dev.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,9 @@ prompt-toolkit==3.0.51
264264
# ipython
265265
# jupyter-console
266266
psutil==7.0.0
267-
# via ipykernel
267+
# via
268+
# -c requirements/base.txt
269+
# ipykernel
268270
ptyprocess==0.7.0
269271
# via
270272
# pexpect
@@ -351,7 +353,7 @@ terminado==0.18.1
351353
# jupyter-server-terminals
352354
tinycss2==1.4.0
353355
# via bleach
354-
tornado==6.4.2
356+
tornado==6.5
355357
# via
356358
# ipykernel
357359
# jupyter-client
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0" # pragma: no cover
1+
__version__ = "1.0.1" # pragma: no cover

unstructured_inference/models/tables.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def initialize(
6767
logger.info("Loading the table structure model ...")
6868
cached_current_verbosity = logging.get_verbosity()
6969
logging.set_verbosity_error()
70-
self.model = TableTransformerForObjectDetection.from_pretrained(model)
70+
self.model = TableTransformerForObjectDetection.from_pretrained(
71+
model, device_map=self.device
72+
)
7173
logging.set_verbosity(cached_current_verbosity)
7274
self.model.eval()
7375

@@ -77,7 +79,6 @@ def initialize(
7779
raise ImportError(
7880
"Review the parameters to initialize a UnstructuredTableTransformerModel obj",
7981
)
80-
self.model.to(device)
8182

8283
def get_structure(
8384
self,

0 commit comments

Comments
 (0)