Skip to content

Commit 4ed4b38

Browse files
committed
update code
1 parent 1f60d91 commit 4ed4b38

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

demo/demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def main():
2121
from pypandoc import convert_text
2222

2323
# build model
24-
model = build_model(args.ckpt_path, max_new_tokens=4096, max_time=120, use_gpu=(not args.cpu))
24+
model = build_model(args.ckpt_path, max_new_tokens=4096, max_time=60)
2525
if not args.cpu:
2626
model = model.cuda()
2727

struct_eqtable/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66

77

88
class StructTable(nn.Module):
9-
def __init__(self, model_path, max_new_tokens=2048, max_time=60, use_gpu=True):
9+
def __init__(self, model_path, max_new_tokens=2048, max_time=60):
1010
super().__init__()
1111
self.model_path = model_path
1212
self.max_new_tokens = max_new_tokens
1313
self.max_generate_time = max_time
14-
self.use_gpu = use_gpu
1514

1615
# init model and image processor from ckpt path
1716
self.init_image_processor(model_path)
@@ -37,9 +36,10 @@ def forward(self, image):
3736
images=image,
3837
return_tensors='pt',
3938
)
40-
if self.use_gpu:
41-
for k, v in image_tokens.items():
42-
image_tokens[k] = v.cuda()
39+
40+
device = next(self.parameters()).device
41+
for k, v in image_tokens.items():
42+
image_tokens[k] = v.to(device)
4343

4444
# generate text from image tokens
4545
model_output = self.model.generate(

0 commit comments

Comments
 (0)