Skip to content

Commit 33c7bdd

Browse files
author
Ubuntu
committed
pinning transformer version and added missin parameter
1 parent 483fe1c commit 33c7bdd

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ full = [
159159
"xformers==0.0.30",
160160
"stable-fast-pruna==1.0.7",
161161
]
162+
zipar = [
163+
"transformers==4.54.0",
164+
]
162165
dev = [
163166
"wget",
164167
"python-dotenv",

src/pruna/algorithms/zipar/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ def prepare_input_and_cache(self, input_ids, model_kwargs, generation_config, de
544544
# batch_size should account for both conditional/unconditional input; hence multiplied by 2
545545
batch_size=batch_size * 2,
546546
max_cache_len=self.num_image_tokens + seq_len,
547+
device=device,
547548
model_kwargs=model_kwargs,
548549
)
549550

0 commit comments

Comments
 (0)