Description
Hi everyone,
Here is my scenario: I have a machine with 2 GPUs and a running service that keeps two pipelines loaded, one per device. I also have a list of LoRAs (say 10). On each request, I split the batch into 2 parts (the request also carries the information about which LoRA to use), load the LoRAs, and run the forward pass.
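For context, the service setup looks roughly like this (a minimal sketch; the pipeline class and checkpoint are placeholders for whatever model you actually serve):

```python
import torch
from diffusers import DiffusionPipeline

devices = ["cuda:0", "cuda:1"]

# One pipeline per GPU, kept alive for the lifetime of the service.
# The checkpoint below is only an example.
pipes = [
    DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
    ).to(device)
    for device in devices
]
```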
The problem is that with every parallelization method I have tried (threading, multiprocessing), the best I have achieved is pre-loading the LoRA state dicts on the CPU, then moving them to the GPU, and only then calling load_lora_weights from the state dict.
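Concretely, the CPU-side cache is built once at startup, something like this (a sketch; `lora_paths` is a hypothetical mapping from adapter name to a local `.safetensors` file):

```python
from safetensors.torch import load_file

# Hypothetical mapping: adapter name -> local .safetensors file.
lora_paths = {
    "style_a": "/loras/style_a.safetensors",
    "style_b": "/loras/style_b.safetensors",
}

# Preload every LoRA state dict on the CPU once, so that per-request work
# is only the CPU -> GPU move plus load_lora_weights from a state dict.
lora_cache = {
    name: {"state_dict": load_file(path, device="cpu")}
    for name, path in lora_paths.items()
}
```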
Whenever I try to go further and run the loading chunk itself in parallel threads, the pipe starts to produce either complete noise or a black image.
Where I would really appreciate help:
- Advice on an elegant way to load multiple LoRAs at once into one pipe (all the examples in the documentation load them one by one).
- Given 2 pipes on 2 different devices, how to parallelize loading a LoRA into each pipe on its corresponding device (see the sketch after this list for the shape I have in mind).
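For the second point, what I have in mind is one worker thread per device, along these lines (a sketch only: `apply_loras_to_pipe` is a hypothetical per-pipe helper, and running things this way is exactly where I see the noise/black-image corruption):

```python
from concurrent.futures import ThreadPoolExecutor

def apply_loras_to_pipe(pipe, device, adapter_names, lora_cache, lora_names, lora_strengths):
    # Hypothetical per-pipe version of the loop body of the function below:
    # unload old adapters, move the cached state dicts to `device`,
    # call load_lora_weights per adapter, then set_adapters.
    ...

# One worker per device, so each pipe applies its LoRAs on its own GPU.
with ThreadPoolExecutor(max_workers=len(pipes)) as executor:
    futures = [
        executor.submit(
            apply_loras_to_pipe,
            pipe, device, adapter_names, lora_cache, lora_names, lora_strengths,
        )
        for pipe, device in zip(pipes, devices)
    ]
    for future in futures:
        future.result()  # surface any exception raised in a worker
```

My current (sequential) implementation is below: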
```python
import time
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch

logger = logging.getLogger(__name__)


def apply_multiple_loras_from_cache(pipes, adapter_names, lora_cache, lora_names, lora_strengths, devices):
    for device_index, pipe in enumerate(pipes):
        logger.info(f"Starting setup for device {devices[device_index]}")

        # Step 1: Unload any previously applied LoRAs
        start = time.time()
        pipe.unload_lora_weights(reset_to_overwritten_params=False)
        logger.info(f"[Device {device_index}] Unload time: {time.time() - start:.3f}s")

        # Step 2: Parallelize the CPU -> GPU state_dict move
        def move_to_device(name):
            # Move every tensor of the cached state dict to this pipe's device
            # and cast it to the pipe's dtype in a single .to() call.
            return name, {
                k: v.to(device=devices[device_index], dtype=pipe.dtype, non_blocking=True)
                for k, v in lora_cache[name]['state_dict'].items()
            }

        start = time.time()
        with ThreadPoolExecutor() as executor:
            future_to_name = {executor.submit(move_to_device, name): name for name in adapter_names}
            results = [future.result() for future in as_completed(future_to_name)]
        logger.info(f"[Device {device_index}] State dict move + dtype conversion time: {time.time() - start:.3f}s")

        # Step 3: Load each adapter from its already-on-device state dict
        start = time.time()
        for adapter_name, state_dict in results:
            pipe.load_lora_weights(
                pretrained_model_name_or_path_or_dict=state_dict,
                adapter_name=adapter_name,
            )
        logger.info(f"[Device {device_index}] Load adapter weights time: {time.time() - start:.3f}s")

        # Step 4: Activate the adapters with their per-adapter strengths
        start = time.time()
        pipe.set_adapters(lora_names, adapter_weights=lora_strengths)
        logger.info(f"[Device {device_index}] Set adapter weights time: {time.time() - start:.3f}s")

    torch.cuda.empty_cache()
    logger.info("All LoRAs applied and GPU cache cleared.")
```