Commit 0e2c6e7

prevent glmv4.1 from loading twice

1 parent c10b3fb
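
The hunks below do not show the guard itself, so the exact mechanism is not recoverable here; as a point of reference only, a minimal load-once pattern looks like the following sketch. The names `_loaded_models` and `load_once` are hypothetical and not taken from this commit.

    # Hypothetical sketch, not taken from this commit's hunks: cache the loaded
    # model so a second call reuses the existing instance instead of re-loading.
    _loaded_models = {}

    def load_once(model_id, loader):
        """Call loader() the first time model_id is requested; reuse it afterwards."""
        if model_id not in _loaded_models:
            _loaded_models[model_id] = loader()
        return _loaded_models[model_id]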

src/module_process_images.py

Lines changed: 8 additions & 8 deletions
@@ -617,20 +617,20 @@ def initialize_model_and_tokenizer(self):
         save_dir = info["cache_dir"]
         cache_dir = CACHE_DIR / save_dir
         cache_dir.mkdir(parents=True, exist_ok=True)
-
+
         self.device = torch.device("cuda")
         use_bf16 = torch.cuda.get_device_capability()[0] >= 8
         dtype = torch.bfloat16 if use_bf16 else torch.float16
-
+
         quant_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
             bnb_4bit_compute_dtype=dtype
         )
-
+
         # Import the specific model class
         from transformers import Glm4vForConditionalGeneration
-
+
         model = Glm4vForConditionalGeneration.from_pretrained(
             model_id,
             token=False,
@@ -642,14 +642,14 @@ def initialize_model_and_tokenizer(self):
             device_map="auto",
             attn_implementation="sdpa"
         ).eval()
-
+
         processor = AutoProcessor.from_pretrained(
             model_id,
             use_fast=True,
             trust_remote_code=True,
             cache_dir=cache_dir
         )
-
+
         precision_str = "bfloat16" if use_bf16 else "float16"
         device_str = "CUDA" if self.device == "cuda" else "CPU"
         my_cprint(f"{chosen_model} (Thinking Mode) loaded into memory on {device_str} ({precision_str})", "green")
@@ -670,11 +670,11 @@ def process_single_image(self, raw_image):
 
         generated_tokens = outputs[0][len(inputs.input_ids[0]):]
         response = self.processor.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()
-
+
         # Extract content between <answer> and </answer> tags
         if '<answer>' in response and '</answer>' in response:
            start_idx = response.find('<answer>') + len('<answer>')
            end_idx = response.find('</answer>')
            response = response[start_idx:end_idx].strip()
-
+
         return ' '.join(line.strip() for line in response.split('\n') if line.strip())
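
The tag-stripping logic in this last hunk is self-contained; as a standalone helper it reads roughly like this (the function name `extract_answer` is illustrative, not from the module):

    def extract_answer(response: str) -> str:
        # Keep only the text between <answer> and </answer>, if both tags are present
        if '<answer>' in response and '</answer>' in response:
            start_idx = response.find('<answer>') + len('<answer>')
            end_idx = response.find('</answer>')
            response = response[start_idx:end_idx].strip()
        # Collapse the remaining text onto a single line
        return ' '.join(line.strip() for line in response.split('\n') if line.strip())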
