@@ -617,20 +617,20 @@ def initialize_model_and_tokenizer(self):
617617 save_dir = info ["cache_dir" ]
618618 cache_dir = CACHE_DIR / save_dir
619619 cache_dir .mkdir (parents = True , exist_ok = True )
620-
620+
621621 self .device = torch .device ("cuda" )
622622 use_bf16 = torch .cuda .get_device_capability ()[0 ] >= 8
623623 dtype = torch .bfloat16 if use_bf16 else torch .float16
624-
624+
625625 quant_config = BitsAndBytesConfig (
626626 load_in_4bit = True ,
627627 bnb_4bit_quant_type = "nf4" ,
628628 bnb_4bit_compute_dtype = dtype
629629 )
630-
630+
631631 # Import the specific model class
632632 from transformers import Glm4vForConditionalGeneration
633-
633+
634634 model = Glm4vForConditionalGeneration .from_pretrained (
635635 model_id ,
636636 token = False ,
@@ -642,14 +642,14 @@ def initialize_model_and_tokenizer(self):
642642 device_map = "auto" ,
643643 attn_implementation = "sdpa"
644644 ).eval ()
645-
645+
646646 processor = AutoProcessor .from_pretrained (
647647 model_id ,
648648 use_fast = True ,
649649 trust_remote_code = True ,
650650 cache_dir = cache_dir
651651 )
652-
652+
653653 precision_str = "bfloat16" if use_bf16 else "float16"
654654 device_str = "CUDA" if self .device == "cuda" else "CPU"
655655 my_cprint (f"{ chosen_model } (Thinking Mode) loaded into memory on { device_str } ({ precision_str } )" , "green" )
@@ -670,11 +670,11 @@ def process_single_image(self, raw_image):
670670
671671 generated_tokens = outputs [0 ][len (inputs .input_ids [0 ]):]
672672 response = self .processor .decode (generated_tokens , skip_special_tokens = True , clean_up_tokenization_spaces = False ).strip ()
673-
673+
674674 # Extract content between <answer> and </answer> tags
675675 if '<answer>' in response and '</answer>' in response :
676676 start_idx = response .find ('<answer>' ) + len ('<answer>' )
677677 end_idx = response .find ('</answer>' )
678678 response = response [start_idx :end_idx ].strip ()
679-
679+
680680 return ' ' .join (line .strip () for line in response .split ('\n ' ) if line .strip ())
0 commit comments