|  | 
|  | 1 | +# Copyright 2024 The HuggingFace Team. All rights reserved. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +import torch | 
|  | 16 | +import torch.distributed as dist | 
|  | 17 | + | 
|  | 18 | +from ..utils import get_logger | 
|  | 19 | +from ._common import _BATCHED_INPUT_IDENTIFIERS | 
|  | 20 | +from .hooks import HookRegistry, ModelHook | 
|  | 21 | + | 
|  | 22 | + | 
logger = get_logger(__name__)  # pylint: disable=invalid-name

# Registry key under which the CFG-parallel hook is attached to a module.
_CFG_PARALLEL = "cfg_parallel"
|  | 26 | + | 
|  | 27 | + | 
class CFGParallelHook(ModelHook):
    """Example hook that data-parallelizes classifier-free guidance (CFG).

    Each of exactly two ranks runs the wrapped module's forward on its half of
    every batched keyword input, then the halves of the first output are
    all-gathered and re-concatenated so every rank returns the full batch.
    """

    def initialize_hook(self, module):
        """Fail fast if `torch.distributed` has not been initialized.

        Args:
            module: The module this hook is being attached to.

        Returns:
            The (unmodified) module.

        Raises:
            RuntimeError: If no distributed process group is initialized.
        """
        if not dist.is_initialized():
            raise RuntimeError("Distributed environment not initialized.")
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Run the wrapped forward on this rank's half-batch and gather results.

        Args:
            module: The hooked module (unused directly; forwarding goes through
                `self.fn_ref.original_forward`).
            *args: Positional arguments, passed through un-split (see warning).
            **kwargs: Keyword arguments; those named in
                `_BATCHED_INPUT_IDENTIFIERS` are chunked along dim 0 per rank.

        Returns:
            A tuple `(sample, *rest)` or, when `return_dict` is truthy, the
            original output class rebuilt with the gathered sample.
        """
        if args:
            logger.warning(
                "CFGParallelHook is an example hook that does not work with batched positional arguments. Please use with caution."
            )

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # `assert` is stripped under `python -O`, so validate with an explicit
        # raise to guarantee the two-process requirement is always enforced.
        if world_size != 2:
            raise ValueError("This is an example hook designed to only work with 2 processes.")

        # Each rank keeps only its own chunk of every recognized batched input.
        for key in list(kwargs.keys()):
            if key not in _BATCHED_INPUT_IDENTIFIERS or kwargs[key] is None:
                continue
            kwargs[key] = torch.chunk(kwargs[key], world_size, dim=0)[rank].contiguous()

        output = self.fn_ref.original_forward(*args, **kwargs)
        sample = output[0]
        # Reassemble the full batch: gather every rank's partial sample and
        # concatenate along the batch dimension.
        sample_list = [torch.empty_like(sample) for _ in range(world_size)]
        dist.all_gather(sample_list, sample)
        sample = torch.cat(sample_list, dim=0).contiguous()

        # NOTE(review): defaults to tuple output when `return_dict` is absent —
        # confirm this matches the wrapped model's own default.
        return_dict = kwargs.get("return_dict", False)
        if not return_dict:
            return (sample, *output[1:])
        return output.__class__(sample, *output[1:])
|  | 60 | + | 
|  | 61 | + | 
def apply_cfg_parallel(module: torch.nn.Module) -> None:
    """Register a `CFGParallelHook` on `module` under the CFG-parallel key.

    Args:
        module: The `torch.nn.Module` to attach the hook to.
    """
    hook_registry = HookRegistry.check_if_exists_or_initialize(module)
    hook_registry.register_hook(CFGParallelHook(), _CFG_PARALLEL)
0 commit comments